diff --git a/federatedscope/trafficflow/model/FedDGCNv2.py b/federatedscope/trafficflow/model/FedDGCNv2.py
index f565131..1b92e5d 100644
--- a/federatedscope/trafficflow/model/FedDGCNv2.py
+++ b/federatedscope/trafficflow/model/FedDGCNv2.py
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
 
+
 class DGCRM(nn.Module):
     def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
         super(DGCRM, self).__init__()
@@ -24,7 +25,8 @@ class DGCRM(nn.Module):
             state = init_state[i]
             inner_states = []
             for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
+                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                 inner_states.append(state)
             output_hidden.append(state)
             current_inputs = torch.stack(inner_states, dim=1)
@@ -34,7 +36,8 @@ class DGCRM(nn.Module):
         init_states = []
         for i in range(self.num_layers):
             init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
-        return torch.stack(init_states, dim=0) #(num_layers, B, N, hidden_dim)
+        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)
+
 
 # Build you torch or tf model class here
 class FedDGCN(nn.Module):
@@ -49,7 +52,7 @@ class FedDGCN(nn.Module):
         self.num_layers = args.num_layers
         self.use_D = args.use_day
         self.use_W = args.use_week
-        self.dropout1 = nn.Dropout(p=args.dropout) # 0.1
+        self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
         self.dropout2 = nn.Dropout(p=args.dropout)
         self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
         self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
@@ -72,16 +75,16 @@ class FedDGCN(nn.Module):
     def forward(self, source):
         node_embedding1 = self.node_embeddings1
         if self.use_D:
-            t_i_d_data = source[..., 1] 
+            t_i_d_data = source[..., 1]
             T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
 
         if self.use_W:
-            d_i_w_data = source[..., 2] 
+            d_i_w_data = source[..., 2]
             D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
 
-        node_embeddings=[node_embedding1,self.node_embeddings1]
+        node_embeddings = [node_embedding1, self.node_embeddings1]
 
         source = source[..., 0].unsqueeze(-1)
 
@@ -107,15 +110,11 @@ class FederatedFedDGCN(nn.Module):
         super(FederatedFedDGCN, self).__init__()
 
         # Initializing with None, we will populate model_list during the forward pass
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model_list = None
-        self.main_model = FedDGCN(args)  # Initialize a single FedDGCN model (for aggregation)
+        self.graph_num = (args.num_nodes + args.minigraph_size - 1) // args.minigraph_size
         self.args = args
-        self.subgraph_num = 0
-        self.num_node = args.minigraph_size
-        self.input_dim = args.input_dim
-        self.hidden_dim = args.rnn_units
-        self.output_dim = args.output_dim
-        self.horizon = args.horizon
+        self.model_list = ModuleList(FedDGCN(self.args).to(self.device) for _ in range(self.graph_num))
 
     def forward(self, source):
         """
@@ -130,11 +129,6 @@ class FederatedFedDGCN(nn.Module):
         """
         self.subgraph_num = source.shape[2]
 
-        # Initialize model_list if it hasn't been initialized yet
-        if self.model_list is None:
-            # Initialize model_list with FedDGCN models, one for each subgraph
-            self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
-
         # Initialize a list to store the outputs of each subgraph model
         subgraph_outputs = []
 
@@ -150,7 +144,7 @@ class FederatedFedDGCN(nn.Module):
 
         # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
        output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)
 
-        self.update_main_model()
+        # self.update_main_model()
 
         return output_tensor
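
Note: the net effect of the FederatedFedDGCN hunks above is to replace lazy, first-forward-pass construction of model_list (seeded with a shared main_model) with eager construction in __init__, sized by ceiling division of num_nodes over minigraph_size. Below is a minimal, self-contained sketch of that pattern; SubgraphModel, the dimension values, and the tensor shapes are illustrative stand-ins, not the real FedDGCN or its args.

# Sketch of the eager per-subgraph initialization (stand-ins, not FedDGCN).
import torch
import torch.nn as nn


class SubgraphModel(nn.Module):
    """Hypothetical stand-in for FedDGCN: one model per subgraph."""

    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.proj(x)


num_nodes, minigraph_size, dim = 307, 32, 16

# Ceiling division, as in the diff: the last subgraph may be partially filled.
graph_num = (num_nodes + minigraph_size - 1) // minigraph_size  # 10

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build every per-subgraph model up front so all submodules are registered
# (and visible to optimizers / federated weight exchange) before the first batch.
model_list = nn.ModuleList(
    SubgraphModel(dim).to(device) for _ in range(graph_num)
)

# Route a (batch, horizon, subgraph_num, subgraph_size, dim) tensor
# subgraph-by-subgraph, then restack on dim=2, as the diff's forward() does.
x = torch.randn(4, 12, graph_num, minigraph_size, dim, device=device)
outputs = torch.stack(
    [model_list[g](x[:, :, g]) for g in range(graph_num)], dim=2
)
print(outputs.shape)  # torch.Size([4, 12, 10, 32, 16])

One practical consequence of the eager construction, presumably the motivation here: state_dict() has a fixed set of parameters from step zero, so a federated aggregator can exchange weights before the first batch arrives, whereas the lazy version only materialized the submodules inside forward().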