finish v2 model output

HengZhang 2024-11-28 10:35:31 +08:00
parent 1b25920188
commit 0b41f04d3c
2 changed files with 58 additions and 5 deletions

federatedscope/core/auxiliaries/model_builder.py

@@ -209,8 +209,8 @@ def get_model(model_config, local_data=None, backend='torch'):
             from federatedscope.trafficflow.model.FedDGCN import FedDGCN
             model = FedDGCN(model_config)
         else:
-            from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN
-            model = FedDGCN(model_config)
+            from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN
+            model = FederatedFedDGCN(model_config)
     else:
         raise ValueError('Model {} is not provided'.format(model_config.type))
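
For context, the surviving hunk reads as the tail of a model-type dispatch inside get_model. A minimal sketch of how the v2 branch now resolves, assuming a plausible surrounding structure; the dispatch key and the selecting flag (use_v2 here) are purely hypothetical, since the diff does not show them:

    def get_model(model_config, local_data=None, backend='torch'):
        ...
        if model_config.type == 'feddgcn':  # assumed dispatch key
            if not getattr(model_config, 'use_v2', False):  # hypothetical flag
                from federatedscope.trafficflow.model.FedDGCN import FedDGCN
                model = FedDGCN(model_config)
            else:
                # New in this commit: the v2 wrapper that fans data out to
                # one FedDGCN per subgraph (see FedDGCNv2.py below)
                from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN
                model = FederatedFedDGCN(model_config)
        else:
            raise ValueError('Model {} is not provided'.format(model_config.type))
        return model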

federatedscope/trafficflow/model/FedDGCNv2.py

@@ -1,3 +1,5 @@
+from torch.nn import ModuleList
+
 from federatedscope.register import register_model
 import torch
 import torch.nn as nn
@@ -41,7 +43,7 @@ class FedDGCN(nn.Module):
     def __init__(self, args):
         super(FedDGCN, self).__init__()
         # print("You are in subminigraph")
-        self.num_node = args.num_nodes
+        self.num_node = args.minigraph_size
         self.input_dim = args.input_dim
         self.hidden_dim = args.rnn_units
         self.output_dim = args.output_dim
@@ -60,9 +62,9 @@ class FedDGCN(nn.Module):
         nn.init.xavier_uniform_(self.T_i_D_emb)
         nn.init.xavier_uniform_(self.D_i_W_emb)
-        self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
+        self.encoder1 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
                               args.embed_dim, args.num_layers)
-        self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
+        self.encoder2 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
                               args.embed_dim, args.num_layers)
         # predictor
         self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
@@ -100,3 +102,54 @@ class FedDGCN(nn.Module):
         output2 = self.end_conv3(output2)
         return output1 + output2
+
+
+class FederatedFedDGCN(nn.Module):
+    def __init__(self, args):
+        super(FederatedFedDGCN, self).__init__()
+        # Start with None; model_list is populated during the first forward pass
+        self.model_list = None
+        self.main_model = FedDGCN(args)  # a single FedDGCN model (used for aggregation)
+        self.args = args
+        self.subgraph_num = 0
+        self.num_node = args.minigraph_size
+        self.input_dim = args.input_dim
+        self.hidden_dim = args.rnn_units
+        self.output_dim = args.output_dim
+        self.horizon = args.horizon
+
+    def forward(self, source):
+        """
+        Forward pass for the federated model. Each subgraph processes its
+        portion of the data, and the results are then aggregated.
+
+        Arguments:
+            source: tensor of shape (batch_size, horizon, subgraph_num, subgraph_size, dims)
+
+        Returns:
+            Aggregated output of shape (batch_size, horizon, subgraph_num, subgraph_size, dims)
+        """
+        self.subgraph_num = source.shape[2]
+        # Lazily initialize model_list with one FedDGCN model per subgraph
+        if self.model_list is None:
+            self.model_list = ModuleList(
+                [self.main_model] +
+                [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
+        # Collect the output of each subgraph model
+        subgraph_outputs = []
+        for i in range(self.subgraph_num):
+            # Extract the subgraph-specific data: (batch_size, horizon, subgraph_size, dims)
+            subgraph_data = source[:, :, i, :, :]
+            # Forward pass through the corresponding subgraph model
+            subgraph_output = self.model_list[i](subgraph_data)
+            subgraph_outputs.append(subgraph_output)
+        # Stack along dim=2 to restore (batch_size, horizon, subgraph_num, subgraph_size, dims)
+        output_tensor = torch.stack(subgraph_outputs, dim=2)
+        return output_tensor
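
Taken together, FederatedFedDGCN fans a 5-D input out to one FedDGCN per subgraph and stacks the per-subgraph predictions back along dim=2. A minimal usage sketch, assuming the repository is importable; the args fields mirror those read in __init__, but their values are invented for illustration and the real config may carry more entries:

    import torch
    from types import SimpleNamespace
    from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN

    # Hypothetical hyperparameters; values are made up for illustration
    args = SimpleNamespace(minigraph_size=30, input_dim=1, rnn_units=64,
                           output_dim=1, horizon=12, cheb_order=2,
                           embed_dim=8, num_layers=1)

    model = FederatedFedDGCN(args)
    source = torch.randn(8, 12, 4, 30, 1)  # (batch, horizon, subgraph_num, subgraph_size, dims)
    out = model(source)                    # per-subgraph predictions stacked along dim=2

One design note on the lazy initialization: the per-subgraph parameters do not exist until the model has seen its first batch, so an optimizer constructed before that first forward call will not track them. Also, model_list[0] is the same module object as main_model, so that entry is not an independent copy.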