156 lines
6.7 KiB
Python
156 lines
6.7 KiB
Python
from torch.nn import ModuleList
|
|
|
|
from federatedscope.register import register_model
|
|
import torch
|
|
import torch.nn as nn
|
|
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
|
|
|
|
class DGCRM(nn.Module):
    """Stack of ``DGCRUCell`` layers unrolled over the time axis.

    The first cell is built with ``dim_in`` input features; every further
    cell is built with ``dim_out`` input features, because it consumes the
    hidden states produced by the layer below it.

    Args:
        node_num: Size expected on axis 2 of the input passed to ``forward``.
        dim_in: Size expected on axis 3 of the input passed to ``forward``.
        dim_out: Output/hidden size handed to every cell.
        cheb_k: Forwarded unchanged to ``DGCRUCell``.
        embed_dim: Forwarded unchanged to ``DGCRUCell``.
        num_layers: Number of stacked cells; must be at least 1.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(DGCRM, self).__init__()
        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # Layer 0 sees the raw features, deeper layers see dim_out features.
        # Cells are still created in layer order, so seeded initialisation
        # matches the previous append-based construction.
        layer_input_dims = [dim_in] + [dim_out] * (num_layers - 1)
        self.DGCRM_cells = nn.ModuleList([
            DGCRUCell(node_num, layer_dim, dim_out, cheb_k, embed_dim)
            for layer_dim in layer_input_dims
        ])

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the whole sequence.

        Args:
            x: 4-D tensor whose axis 1 is iterated as time, with axis 2 equal
                to ``node_num`` and axis 3 equal to ``input_dim``.
            init_state: Indexable by layer; ``init_state[i]`` is the initial
                state handed to layer ``i``.
            node_embeddings: Two-element sequence. Element 0 is sliced per
                time step with ``[:, t, :, :]``; element 1 is passed to the
                cells unchanged.

        Returns:
            Tuple ``(outputs, last_states)``: the top layer's states stacked
            along axis 1, and a list with the final state of each layer.

        Raises:
            AssertionError: If ``x`` is not 4-D or its node/feature axes do
                not match the sizes given at construction.
        """
        assert (x.dim() == 4 and x.shape[2] == self.node_num
                and x.shape[3] == self.input_dim), (
            f'expected input of shape (B, T, {self.node_num}, {self.input_dim}), '
            f'got {tuple(x.shape)}'
        )
        seq_length = x.shape[1]
        # Both embeddings are loop-invariant; fetch them once.
        dynamic_embedding, static_embedding = node_embeddings[0], node_embeddings[1]
        current_inputs = x
        output_hidden = []
        for i, cell in enumerate(self.DGCRM_cells):
            state = init_state[i]
            inner_states = []
            for t in range(seq_length):
                state = cell(
                    current_inputs[:, t, :, :],
                    state,
                    [dynamic_embedding[:, t, :, :], static_embedding],
                )
                inner_states.append(state)
            output_hidden.append(state)
            # The per-step states of this layer are the next layer's input.
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """Stack each cell's ``init_hidden_state(batch_size)`` along a new axis 0.

        Per the original author's note the result is laid out as
        (num_layers, B, N, hidden_dim).
        """
        return torch.stack(
            [cell.init_hidden_state(batch_size) for cell in self.DGCRM_cells],
            dim=0,
        )
|
|
|
|
# Build your torch or tf model class here
|
|
class FedDGCN(nn.Module):
    """Two-stage recurrent forecaster built from two ``DGCRM`` encoders.

    The first encoder produces a prediction (``end_conv1``) and a
    reconstruction of its input (``end_conv2``). The reconstruction is
    subtracted from the input and the residual is fed to the second encoder,
    whose prediction (``end_conv3``) is added to the first one.

    Args:
        args: Configuration object. The attributes read here are
            ``minigraph_size``, ``input_dim``, ``rnn_units``, ``output_dim``,
            ``horizon``, ``num_layers``, ``use_day``, ``use_week``,
            ``dropout``, ``embed_dim`` and ``cheb_order``.
    """

    # Row counts of the time-of-day and day-of-week embedding tables.
    STEPS_PER_DAY = 288
    DAYS_PER_WEEK = 7

    def __init__(self, args):
        super(FedDGCN, self).__init__()
        self.num_node = args.minigraph_size
        self.input_dim = args.input_dim
        self.hidden_dim = args.rnn_units
        self.output_dim = args.output_dim
        self.horizon = args.horizon
        self.num_layers = args.num_layers
        self.use_D = args.use_day
        self.use_W = args.use_week
        self.dropout1 = nn.Dropout(p=args.dropout)
        self.dropout2 = nn.Dropout(p=args.dropout)

        # Parameter creation order is kept as-is so seeded runs draw the
        # same random numbers as before.
        self.node_embeddings1 = nn.Parameter(
            torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        # NOTE(review): node_embeddings2 is never read by forward() and is not
        # Xavier-initialised below; it is kept so state_dict keys stay stable.
        self.node_embeddings2 = nn.Parameter(
            torch.randn(self.num_node, args.embed_dim), requires_grad=True)
        self.T_i_D_emb = nn.Parameter(torch.empty(self.STEPS_PER_DAY, args.embed_dim))
        self.D_i_W_emb = nn.Parameter(torch.empty(self.DAYS_PER_WEEK, args.embed_dim))

        # Initialize parameters
        nn.init.xavier_uniform_(self.node_embeddings1)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

        self.encoder1 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units,
                              args.cheb_order, args.embed_dim, args.num_layers)
        self.encoder2 = DGCRM(args.minigraph_size, args.input_dim, args.rnn_units,
                              args.cheb_order, args.embed_dim, args.num_layers)

        # predictor
        out_channels = args.horizon * self.output_dim
        self.end_conv1 = nn.Conv2d(1, out_channels, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv2 = nn.Conv2d(1, out_channels, kernel_size=(1, self.hidden_dim), bias=True)
        self.end_conv3 = nn.Conv2d(1, out_channels, kernel_size=(1, self.hidden_dim), bias=True)

    def forward(self, source, i=2):
        """Predict from ``source`` using the two-stage residual scheme.

        Args:
            source: Input tensor. Channel 0 of the last axis is the signal
                fed to the encoders. Channel 1 is read when ``use_D`` is set
                and channel 2 when ``use_W`` is set. ``DGCRM`` requires the
                signal to be 4-D with ``num_node`` on axis 2, and because
                only one channel is kept, ``args.input_dim`` must be 1.
            i: Unused; kept for call compatibility.

        Returns:
            Sum of the two stage predictions (outputs of ``end_conv1`` and
            ``end_conv3``).
        """
        node_embedding1 = self.node_embeddings1
        if self.use_D:
            t_i_d_data = source[..., 1]
            # .long() keeps the index on the input's device, unlike
            # .type(torch.LongTensor), which always yields a CPU tensor.
            # NOTE(review): the float product is truncated, not rounded;
            # presumably channel 1 holds a fraction of the day in [0, 1).
            time_of_day_index = (t_i_d_data * self.STEPS_PER_DAY).long()
            node_embedding1 = torch.mul(node_embedding1, self.T_i_D_emb[time_of_day_index])

        if self.use_W:
            d_i_w_data = source[..., 2]
            day_of_week_index = d_i_w_data.long()
            node_embedding1 = torch.mul(node_embedding1, self.D_i_W_emb[day_of_week_index])

        if node_embedding1.dim() == 2:
            # Neither temporal embedding is enabled. DGCRM slices this tensor
            # with [:, t, :, :], so broadcast the static embedding over the
            # leading (batch, time) axes instead of failing with an IndexError.
            node_embedding1 = node_embedding1.expand(
                source.shape[0], source.shape[1], -1, -1)

        node_embeddings = [node_embedding1, self.node_embeddings1]

        source = source[..., 0].unsqueeze(-1)

        # Stage 1: encode the raw signal and keep only the last time step.
        init_state1 = self.encoder1.init_hidden(source.shape[0])
        output, _ = self.encoder1(source, init_state1, node_embeddings)
        output = self.dropout1(output[:, -1:, :, :])

        output1 = self.end_conv1(output)
        source1 = self.end_conv2(output)

        # Stage 2 models what stage 1 could not reconstruct. end_conv2 emits
        # horizon * output_dim channels on axis 1, so that count has to be
        # broadcastable against the time axis of ``source``.
        source2 = source - source1

        init_state2 = self.encoder2.init_hidden(source2.shape[0])
        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
        output2 = self.dropout2(output2[:, -1:, :, :])
        output2 = self.end_conv3(output2)

        return output1 + output2
|
|
|
|
|
|
class FederatedFedDGCN(nn.Module):
    """Container that runs one ``FedDGCN`` per subgraph.

    ``main_model`` is created eagerly and serves subgraph 0. Models for the
    remaining subgraphs are created on demand, because the number of
    subgraphs is only known once the first batch is seen.

    Args:
        args: Configuration object forwarded to every ``FedDGCN``. This class
            itself reads ``minigraph_size``, ``input_dim``, ``rnn_units``,
            ``output_dim`` and ``horizon``.
    """

    def __init__(self, args):
        super(FederatedFedDGCN, self).__init__()

        # Populated during the first forward pass.
        self.model_list = None
        self.main_model = FedDGCN(args)  # Single FedDGCN model (for aggregation)
        self.args = args
        self.subgraph_num = 0
        self.num_node = args.minigraph_size
        self.input_dim = args.input_dim
        self.hidden_dim = args.rnn_units
        self.output_dim = args.output_dim
        self.horizon = args.horizon

    def _ensure_models(self, required):
        """Make sure ``model_list`` holds at least ``required`` models."""
        if self.model_list is None:
            self.model_list = ModuleList([self.main_model])
        missing = required - len(self.model_list)
        if missing <= 0:
            return
        # Freshly built modules live on the CPU, use the default dtype and
        # are in training mode. Align them with main_model (device, dtype)
        # and with this container (train/eval), which may have been moved or
        # switched before the first forward pass.
        # NOTE(review): an optimizer created before this point will not know
        # about the new parameters - confirm how the trainer handles that.
        reference = next(self.main_model.parameters())
        for _ in range(missing):
            replica = FedDGCN(self.args).to(device=reference.device, dtype=reference.dtype)
            replica.train(self.training)
            self.model_list.append(replica)

    def forward(self, source):
        """
        Forward pass for the federated model. Each subgraph is processed by
        its own model and the results are stacked.

        Arguments:
        - source: Tensor of shape (batchsize, horizon, subgraph_num, subgraph_size, dims)

        Returns:
        - The per-subgraph model outputs stacked along axis 2. Per the
          original author's note this is laid out as
          (batchsize, horizon, subgraph_num, subgraph_size, dims).
        """
        self.subgraph_num = source.shape[2]

        # Create, or extend, the per-subgraph models to cover every subgraph.
        self._ensure_models(self.subgraph_num)

        # Run each subgraph slice (batchsize, horizon, subgraph_size, dims)
        # through its own model.
        subgraph_outputs = [
            self.model_list[i](source[:, :, i, :, :])
            for i in range(self.subgraph_num)
        ]

        return torch.stack(subgraph_outputs, dim=2)
|