finish v2 model output
This commit is contained in:
parent
1b25920188
commit
0b41f04d3c
|
|
@ -209,8 +209,8 @@ def get_model(model_config, local_data=None, backend='torch'):
|
||||||
from federatedscope.trafficflow.model.FedDGCN import FedDGCN
|
from federatedscope.trafficflow.model.FedDGCN import FedDGCN
|
||||||
model = FedDGCN(model_config)
|
model = FedDGCN(model_config)
|
||||||
else:
|
else:
|
||||||
from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN
|
from federatedscope.trafficflow.model.FedDGCNv2 import FederatedFedDGCN
|
||||||
model = FedDGCN(model_config)
|
model = FederatedFedDGCN(model_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError('Model {} is not provided'.format(model_config.type))
|
raise ValueError('Model {} is not provided'.format(model_config.type))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from torch.nn import ModuleList
|
||||||
|
|
||||||
from federatedscope.register import register_model
|
from federatedscope.register import register_model
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
|
@ -41,7 +43,7 @@ class FedDGCN(nn.Module):
|
||||||
def __init__(self, args):
|
def __init__(self, args):
|
||||||
super(FedDGCN, self).__init__()
|
super(FedDGCN, self).__init__()
|
||||||
# print("You are in subminigraph")
|
# print("You are in subminigraph")
|
||||||
self.num_node = args.num_nodes
|
self.num_node = args.minigraph_size
|
||||||
self.input_dim = args.input_dim
|
self.input_dim = args.input_dim
|
||||||
self.hidden_dim = args.rnn_units
|
self.hidden_dim = args.rnn_units
|
||||||
self.output_dim = args.output_dim
|
self.output_dim = args.output_dim
|
||||||
|
|
@ -60,9 +62,9 @@ class FedDGCN(nn.Module):
|
||||||
nn.init.xavier_uniform_(self.T_i_D_emb)
|
nn.init.xavier_uniform_(self.T_i_D_emb)
|
||||||
nn.init.xavier_uniform_(self.D_i_W_emb)
|
nn.init.xavier_uniform_(self.D_i_W_emb)
|
||||||
|
|
||||||
self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
self.encoder1 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
|
||||||
args.embed_dim, args.num_layers)
|
args.embed_dim, args.num_layers)
|
||||||
self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
|
self.encoder2 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units, args.cheb_order,
|
||||||
args.embed_dim, args.num_layers)
|
args.embed_dim, args.num_layers)
|
||||||
# predictor
|
# predictor
|
||||||
self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
|
||||||
|
|
@ -100,3 +102,54 @@ class FedDGCN(nn.Module):
|
||||||
output2 = self.end_conv3(output2)
|
output2 = self.end_conv3(output2)
|
||||||
|
|
||||||
return output1 + output2
|
return output1 + output2
|
||||||
|
|
||||||
|
|
||||||
|
class FederatedFedDGCN(nn.Module):
|
||||||
|
def __init__(self, args):
|
||||||
|
super(FederatedFedDGCN, self).__init__()
|
||||||
|
|
||||||
|
# Initializing with None, we will populate model_list during the forward pass
|
||||||
|
self.model_list = None
|
||||||
|
self.main_model = FedDGCN(args) # Initialize a single FedDGCN model (for aggregation)
|
||||||
|
self.args = args
|
||||||
|
self.subgraph_num = 0
|
||||||
|
self.num_node = args.minigraph_size
|
||||||
|
self.input_dim = args.input_dim
|
||||||
|
self.hidden_dim = args.rnn_units
|
||||||
|
self.output_dim = args.output_dim
|
||||||
|
self.horizon = args.horizon
|
||||||
|
|
||||||
|
def forward(self, source):
|
||||||
|
"""
|
||||||
|
Forward pass for the federated model. Each subgraph processes its portion of the data,
|
||||||
|
and then the results are aggregated.
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
- source: Tensor of shape (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
- Aggregated output (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
|
"""
|
||||||
|
self.subgraph_num = source.shape[2]
|
||||||
|
|
||||||
|
# Initialize model_list if it hasn't been initialized yet
|
||||||
|
if self.model_list is None:
|
||||||
|
# Initialize model_list with FedDGCN models, one for each subgraph
|
||||||
|
self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
|
||||||
|
|
||||||
|
# Initialize a list to store the outputs of each subgraph model
|
||||||
|
subgraph_outputs = []
|
||||||
|
|
||||||
|
# Iterate through the subgraph models
|
||||||
|
for i in range(self.subgraph_num):
|
||||||
|
# Extract the subgraph-specific data
|
||||||
|
subgraph_data = source[:, :, i, :, :] # (batchsize, horizon, subgraph_size, dims)
|
||||||
|
|
||||||
|
# Forward pass for each subgraph model
|
||||||
|
subgraph_output = self.model_list[i](subgraph_data)
|
||||||
|
subgraph_outputs.append(subgraph_output)
|
||||||
|
|
||||||
|
# Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
|
output_tensor = torch.stack(subgraph_outputs, dim=2) # (batchsize, horizon, subgraph_num, subgraph_size, dims)
|
||||||
|
|
||||||
|
return output_tensor
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue