finish v2 model output
This commit is contained in:
parent
1b25920188
commit
0b41f04d3c
|
|
@ -209,8 +209,8 @@ def get_model(model_config, local_data=None, backend='torch'):
|
|||
from federatedscope.trafficflow.model.FedDGCN import FedDGCN
|
||||
model = FedDGCN(model_config)
|
||||
else:
|
||||
from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN
|
||||
model = FedDGCN(model_config)
|
||||
from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN
|
||||
model = FederatedFedDGCN(model_config)
|
||||
else:
|
||||
raise ValueError('Model {} is not provided'.format(model_config.type))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from torch.nn import ModuleList
|
||||
|
||||
from federatedscope.register import register_model
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
|
@ -41,7 +43,7 @@ class FedDGCN(nn.Module):
|
|||
def __init__(self, args):
|
||||
super(FedDGCN, self).__init__()
|
||||
# print("You are in subminigraph")
|
||||
self.num_node = args.num_nodes
|
||||
self.num_node = args.minigraph_size
|
||||
self.input_dim = args.input_dim
|
||||
self.hidden_dim = args.rnn_units
|
||||
self.output_dim = args.output_dim
|
||||
|
|
@ -60,9 +62,9 @@ class FedDGCN(nn.Module):
|
|||
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||
|
||||
self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
self.encoder1 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
args.embed_dim, args.num_layers)
|
||||
self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
self.encoder2 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
|
||||
args.embed_dim, args.num_layers)
|
||||
# predictor
|
||||
self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||
|
|
@ -100,3 +102,54 @@ class FedDGCN(nn.Module):
|
|||
output2 = self.end_conv3(output2)
|
||||
|
||||
return output1 + output2
|
||||
|
||||
|
||||
class FederatedFedDGCN(nn.Module):
|
||||
def __init__(self, args):
|
||||
super(FederatedFedDGCN, self).__init__()
|
||||
|
||||
# Initializing with None, we will populate model_list during the forward pass
|
||||
self.model_list = None
|
||||
self.main_model = FedDGCN(args) # Initialize a single FedDGCN model (for aggregation)
|
||||
self.args = args
|
||||
self.subgraph_num = 0
|
||||
self.num_node = args.minigraph_size
|
||||
self.input_dim = args.input_dim
|
||||
self.hidden_dim = args.rnn_units
|
||||
self.output_dim = args.output_dim
|
||||
self.horizon = args.horizon
|
||||
|
||||
def forward(self, source):
|
||||
"""
|
||||
Forward pass for the federated model. Each subgraph processes its portion of the data,
|
||||
and then the results are aggregated.
|
||||
|
||||
Arguments:
|
||||
- source: Tensor of shape (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||
|
||||
Returns:
|
||||
- Aggregated output (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||
"""
|
||||
self.subgraph_num = source.shape[2]
|
||||
|
||||
# Initialize model_list if it hasn't been initialized yet
|
||||
if self.model_list is None:
|
||||
# Initialize model_list with FedDGCN models, one for each subgraph
|
||||
self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
|
||||
|
||||
# Initialize a list to store the outputs of each subgraph model
|
||||
subgraph_outputs = []
|
||||
|
||||
# Iterate through the subgraph models
|
||||
for i in range(self.subgraph_num):
|
||||
# Extract the subgraph-specific data
|
||||
subgraph_data = source[:, :, i, :, :] # (batchsize, horizon, subgraph_size, dims)
|
||||
|
||||
# Forward pass for each subgraph model
|
||||
subgraph_output = self.model_list[i](subgraph_data)
|
||||
subgraph_outputs.append(subgraph_output)
|
||||
|
||||
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||
|
||||
return output_tensor
|
||||
|
|
|
|||
Loading…
Reference in New Issue