From 0b41f04d3c5a3573b160862f2f0da33c40975fcb Mon Sep 17 00:00:00 2001 From: HengZhang Date: Thu, 28 Nov 2024 10:35:31 +0800 Subject: [PATCH] finish v2 model output --- .../core/auxiliaries/model_builder.py | 4 +- federatedscope/trafficflow/model/FedDGCNv2.py | 59 ++++++++++++++++++- 2 files changed, 58 insertions(+), 5 deletions(-) diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py index 22bfa39..5d94516 100644 --- a/federatedscope/core/auxiliaries/model_builder.py +++ b/federatedscope/core/auxiliaries/model_builder.py @@ -209,8 +209,8 @@ def get_model(model_config, local_data=None, backend='torch'): from federatedscope.trafficflow.model.FedDGCN import FedDGCN model = FedDGCN(model_config) else: - from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN - model = FedDGCN(model_config) + from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN + model = FederatedFedDGCN(model_config) else: raise ValueError('Model {} is not provided'.format(model_config.type)) diff --git a/federatedscope/trafficflow/model/FedDGCNv2.py b/federatedscope/trafficflow/model/FedDGCNv2.py index 927a509..ddd2f92 100644 --- a/federatedscope/trafficflow/model/FedDGCNv2.py +++ b/federatedscope/trafficflow/model/FedDGCNv2.py @@ -1,3 +1,5 @@ +from torch.nn import ModuleList + from federatedscope.register import register_model import torch import torch.nn as nn @@ -41,7 +43,7 @@ class FedDGCN(nn.Module): def __init__(self, args): super(FedDGCN, self).__init__() # print("You are in subminigraph") - self.num_node = args.num_nodes + self.num_node = args.minigraph_size self.input_dim = args.input_dim self.hidden_dim = args.rnn_units self.output_dim = args.output_dim @@ -60,9 +62,9 @@ class FedDGCN(nn.Module): nn.init.xavier_uniform_(self.T_i_D_emb) nn.init.xavier_uniform_(self.D_i_W_emb) - self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order, + self.encoder1 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order, args.embed_dim, args.num_layers) - self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order, + self.encoder2 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order, args.embed_dim, args.num_layers) # predictor self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True) @@ -100,3 +102,54 @@ class FedDGCN(nn.Module): output2 = self.end_conv3(output2) return output1 + output2 + + +class FederatedFedDGCN(nn.Module): + def __init__(self, args): + super(FederatedFedDGCN, self).__init__() + + # Initializing with None, we will populate model_list during the forward pass + self.model_list = None + self.main_model = FedDGCN(args) # Initialize a single FedDGCN model (for aggregation) + self.args = args + self.subgraph_num = 0 + self.num_node = args.minigraph_size + self.input_dim = args.input_dim + self.hidden_dim = args.rnn_units + self.output_dim = args.output_dim + self.horizon = args.horizon + + def forward(self, source): + """ + Forward pass for the federated model. Each subgraph processes its portion of the data, + and then the results are aggregated. + + Arguments: + - source: Tensor of shape (batchsize, horizon, subgraph_num, subgraph_size, dims) + + Returns: + - Aggregated output (batchsize, horizon, subgraph_num, subgraph_size, dims) + """ + self.subgraph_num = source.shape[2] + + # Initialize model_list if it hasn't been initialized yet + if self.model_list is None: + # Initialize model_list with FedDGCN models, one for each subgraph + self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)]) + + # Initialize a list to store the outputs of each subgraph model + subgraph_outputs = [] + + # Iterate through the subgraph models + for i in range(self.subgraph_num): + # Extract the subgraph-specific data + subgraph_data = source[:, :, i, :, :] # (batchsize, horizon, subgraph_size, dims) + + # Forward pass for each subgraph model + subgraph_output = self.model_list[i](subgraph_data) + subgraph_outputs.append(subgraph_output) + + # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims) + output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims) + + return output_tensor