import torch
import torch.nn as nn
from torch.nn import ModuleList

from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell


class DGCRM(nn.Module):
    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim,
                 num_layers=1):
        super(DGCRM, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        self.DGCRM_cells = nn.ModuleList()
        self.DGCRM_cells.append(
            DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
        for _ in range(1, num_layers):
            self.DGCRM_cells.append(
                DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim))

    def forward(self, x, init_state, node_embeddings):
        # x: (B, T, N, D). node_embeddings[0] is a time-varying embedding
        # (B, T, N, E); node_embeddings[1] is the static (N, E) embedding.
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for i in range(self.num_layers):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
                state = self.DGCRM_cells[i](
                    current_inputs[:, t, :, :], state,
                    [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                inner_states.append(state)
            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        init_states = []
        for i in range(self.num_layers):
            init_states.append(
                self.DGCRM_cells[i].init_hidden_state(batch_size))
        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)


# Build your torch or tf model class here
class FedDGCN(nn.Module):
    def __init__(self, args):
        super(FedDGCN, self).__init__()
        self.num_node = args.minigraph_size
        self.input_dim = args.input_dim
        self.hidden_dim = args.rnn_units
        self.output_dim = args.output_dim
        self.horizon = args.horizon
        self.num_layers = args.num_layers
        self.use_D = args.use_day
        self.use_W = args.use_week
        self.dropout1 = nn.Dropout(p=args.dropout)  # e.g. 0.1
        self.dropout2 = nn.Dropout(p=args.dropout)

        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        # Note: node_embeddings2 is defined (and not xavier-initialized)
        # but currently unused in forward().
        self.node_embeddings2 = nn.Parameter(
            torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))

        # Initialize parameters
        nn.init.xavier_uniform_(self.node_embeddings1)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

        self.encoder1 = DGCRM(args.minigraph_size, args.input_dim,
                              args.rnn_units, args.cheb_order, args.embed_dim,
                              args.num_layers)
        self.encoder2 = DGCRM(args.minigraph_size, args.input_dim,
                              args.rnn_units, args.cheb_order, args.embed_dim,
                              args.num_layers)

        # predictor
        self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim,
                                   kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim,
                                   kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim,
                                   kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source):
        # source: (B, T, N, C) with channel 0 = traffic flow, channel 1 =
        # time-of-day in [0, 1), channel 2 = day-of-week index in {0..6}.
        node_embedding1 = self.node_embeddings1
        if self.use_D:
            t_i_d_data = source[..., 1]
            # .long() keeps the indices on the same device as the embedding
            # table (the original .type(torch.LongTensor) forced them to CPU,
            # which breaks indexing when the model lives on CUDA).
            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()]
            node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
        if self.use_W:
            d_i_w_data = source[..., 2]
            D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()]
            node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
        # Note: DGCRM indexes node_embeddings[0] along a time axis, so at
        # least one of use_D / use_W should be enabled to make it 4-D.
        node_embeddings = [node_embedding1, self.node_embeddings1]

        source = source[..., 0].unsqueeze(-1)

        # First branch: encode the raw series and predict.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])
        output1 = self.end_conv1(output)

        # Second branch: encode the residual of the first branch's estimate.
        source1 = self.end_conv2(output)
        source2 = source - source1
        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)

        return output1 + output2
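
# ---------------------------------------------------------------------------
# Shape sketch for FedDGCN (an illustration, not part of the original file).
# The `args` fields below are assumptions mirroring the constructor above.
# It shows the expected input layout (flow / time-of-day / day-of-week
# channels) and why the sequence length must equal horizon * output_dim:
# the second branch subtracts the first branch's prediction from the input.
# ---------------------------------------------------------------------------
def _feddgcn_shape_demo():
    from types import SimpleNamespace
    args = SimpleNamespace(
        minigraph_size=4, input_dim=1, output_dim=1, rnn_units=32,
        num_layers=1, cheb_order=2, embed_dim=8, horizon=12,
        dropout=0.1, use_day=True, use_week=True)
    model = FedDGCN(args)
    B, T, N = 2, args.horizon, args.minigraph_size
    flow = torch.randn(B, T, N, 1)
    t_i_d = torch.rand(B, T, N, 1)                     # time-of-day in [0, 1)
    d_i_w = torch.randint(0, 7, (B, T, N, 1)).float()  # day-of-week index
    out = model(torch.cat([flow, t_i_d, d_i_w], dim=-1))
    assert out.shape == (B, args.horizon * args.output_dim, N, 1)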

class FederatedFedDGCN(nn.Module):
    def __init__(self, args):
        super(FederatedFedDGCN, self).__init__()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.args = args
        # One FedDGCN replica per subgraph; ceiling division so that a
        # partially filled last subgraph still gets a model.
        self.graph_num = (args.num_nodes + args.minigraph_size - 1) \
            // args.minigraph_size
        self.model_list = ModuleList(
            FedDGCN(self.args).to(self.device)
            for _ in range(self.graph_num))

    def forward(self, source):
        """
        Forward pass for the federated model. Each subgraph processes its
        portion of the data, and then the results are aggregated.

        Arguments:
        - source: Tensor of shape
          (batchsize, horizon, subgraph_num, subgraph_size, dims)

        Returns:
        - Aggregated output of shape
          (batchsize, horizon, subgraph_num, subgraph_size, dims)
        """
        self.subgraph_num = source.shape[2]

        # Collect the output of each subgraph model.
        # Parallel computation has not been implemented yet, so this may be
        # slower than a single monolithic model.
        subgraph_outputs = []
        for i in range(self.subgraph_num):
            # Extract the subgraph-specific data:
            # (batchsize, horizon, subgraph_size, dims)
            subgraph_data = source[:, :, i, :, :]
            subgraph_output = self.model_list[i](subgraph_data)
            subgraph_outputs.append(subgraph_output)

        # Restack into (batchsize, horizon, subgraph_num, subgraph_size, dims)
        output_tensor = torch.stack(subgraph_outputs, dim=2)

        self.local_aggregate()
        return output_tensor

    def local_aggregate(self):
        """
        Set every model in model_list to the average of all models'
        parameters (FedAvg-style aggregation).
        """
        with torch.no_grad():  # no gradients during the parameter update
            # Average each parameter once, over the original values, so all
            # replicas end up identical. (Updating models in place while
            # iterating would let later models average over already-updated
            # ones, leaving the replicas inconsistent.)
            avg_params = {
                name: torch.stack(
                    [m.state_dict()[name] for m in self.model_list]
                ).mean(dim=0)
                for name, _ in self.model_list[0].named_parameters()
            }
            # Copy the averaged parameters into every replica
            for model in self.model_list:
                for name, param in model.named_parameters():
                    param.data.copy_(avg_params[name])
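
# ---------------------------------------------------------------------------
# Minimal smoke test (a sketch, not part of the original file; the config
# fields below are assumptions mirroring the constructors above, and it only
# runs where federatedscope's DGCRUCell is importable). It feeds random data
# shaped (B, horizon, subgraph_num, subgraph_size, dims) through the
# federated wrapper and checks that local_aggregate leaves every replica
# with identical parameters.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace
    args = SimpleNamespace(
        num_nodes=16, minigraph_size=4, input_dim=1, output_dim=1,
        rnn_units=32, num_layers=1, cheb_order=2, embed_dim=8, horizon=12,
        dropout=0.1, use_day=True, use_week=True)
    fed_model = FederatedFedDGCN(args)
    B, T, G, N = 2, args.horizon, fed_model.graph_num, args.minigraph_size
    flow = torch.randn(B, T, G, N, 1)
    t_i_d = torch.rand(B, T, G, N, 1)
    d_i_w = torch.randint(0, 7, (B, T, G, N, 1)).float()
    x = torch.cat([flow, t_i_d, d_i_w], dim=-1).to(fed_model.device)
    y = fed_model(x)  # forward also triggers local_aggregate()
    print(y.shape)    # (B, horizon * output_dim, graph_num, minigraph_size, 1)
    # After aggregation, every subgraph replica should hold the same weights.
    ref = dict(fed_model.model_list[0].named_parameters())
    for m in fed_model.model_list[1:]:
        for name, p in m.named_parameters():
            assert torch.allclose(p, ref[name])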