diff --git a/federatedscope/core/auxiliaries/model_builder.py b/federatedscope/core/auxiliaries/model_builder.py
index 25cc31d..22bfa39 100644
--- a/federatedscope/core/auxiliaries/model_builder.py
+++ b/federatedscope/core/auxiliaries/model_builder.py
@@ -205,8 +205,12 @@ def get_model(model_config, local_data=None, backend='torch'):
         from federatedscope.nlp.hetero_tasks.model import ATCModel
         model = ATCModel(model_config)
     elif model_config.type.lower() in ['feddgcn']:
-        from federatedscope.trafficflow.model.FedDGCN import FedDGCN
-        model = FedDGCN(model_config)
+        if model_config.use_minigraph is False:
+            from federatedscope.trafficflow.model.FedDGCN import FedDGCN
+            model = FedDGCN(model_config)
+        else:
+            from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN
+            model = FedDGCN(model_config)
     else:
         raise ValueError('Model {} is not provided'.format(model_config.type))
 
diff --git a/federatedscope/trafficflow/model/FedDGCNv2.py b/federatedscope/trafficflow/model/FedDGCNv2.py
new file mode 100644
index 0000000..927a509
--- /dev/null
+++ b/federatedscope/trafficflow/model/FedDGCNv2.py
@@ -0,0 +1,102 @@
+from federatedscope.register import register_model
+import torch
+import torch.nn as nn
+from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
+
+class DGCRM(nn.Module):
+    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
+        super(DGCRM, self).__init__()
+        assert num_layers >= 1, 'At least one DCRNN layer in the Encoder.'
+        self.node_num = node_num
+        self.input_dim = dim_in
+        self.num_layers = num_layers
+        self.DGCRM_cells = nn.ModuleList()
+        self.DGCRM_cells.append(DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
+        for _ in range(1, num_layers):
+            self.DGCRM_cells.append(DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
+
+    def forward(self, x, init_state, node_embeddings):
+        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
+        seq_length = x.shape[1]
+        current_inputs = x
+        output_hidden = []
+        for i in range(self.num_layers):
+            state = init_state[i]
+            inner_states = []
+            for t in range(seq_length):
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                inner_states.append(state)
+            output_hidden.append(state)
+            current_inputs = torch.stack(inner_states, dim=1)
+        return current_inputs, output_hidden
+
+    def init_hidden(self, batch_size):
+        init_states = []
+        for i in range(self.num_layers):
+            init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
+        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)
+
+# Build your torch or tf model class here
+class FedDGCN(nn.Module):
+    def __init__(self, args):
+        super(FedDGCN, self).__init__()
+        # print("You are in subminigraph")
+        self.num_node = args.num_nodes
+        self.input_dim = args.input_dim
+        self.hidden_dim = args.rnn_units
+        self.output_dim = args.output_dim
+        self.horizon = args.horizon
+        self.num_layers = args.num_layers
+        self.use_D = args.use_day
+        self.use_W = args.use_week
+        self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
+        self.dropout2 = nn.Dropout(p=args.dropout)
+        self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
+        self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
+        self.T_i_D_emb = nn.Parameter(torch.empty(288, args.embed_dim))
+        self.D_i_W_emb = nn.Parameter(torch.empty(7, args.embed_dim))
+        # Initialize parameters
+        nn.init.xavier_uniform_(self.node_embeddings1)
+        nn.init.xavier_uniform_(self.T_i_D_emb)
+        nn.init.xavier_uniform_(self.D_i_W_emb)
+
+        self.encoder1 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
+                              args.embed_dim, args.num_layers)
+        self.encoder2 = DGCRM(args.num_nodes, args.input_dim, args.rnn_units, args.cheb_order,
+                              args.embed_dim, args.num_layers)
+        # predictor
+        self.end_conv1 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+        self.end_conv2 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+        self.end_conv3 = nn.Conv2d(1, args.horizon * self.output_dim, kernel_size=(1, self.hidden_dim), bias=True)
+
+    def forward(self, source, i=2):
+        node_embedding1 = self.node_embeddings1
+        if self.use_D:
+            t_i_d_data = source[..., 1]
+            T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
+            node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
+
+        if self.use_W:
+            d_i_w_data = source[..., 2]
+            D_i_W_emb = self.D_i_W_emb[d_i_w_data.type(torch.LongTensor)]
+            node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
+
+        node_embeddings = [node_embedding1, self.node_embeddings1]
+
+        source = source[..., 0].unsqueeze(-1)
+
+        init_state1 = self.encoder1.init_hidden(source.shape[0])
+        output, _ = self.encoder1(source, init_state1, node_embeddings)
+        output = self.dropout1(output[:, -1:, :, :])
+
+        output1 = self.end_conv1(output)
+        source1 = self.end_conv2(output)
+
+        source2 = source - source1
+
+        init_state2 = self.encoder2.init_hidden(source2.shape[0])
+        output2, _ = self.encoder2(source2, init_state2, node_embeddings)
+        output2 = self.dropout2(output2[:, -1:, :, :])
+        output2 = self.end_conv3(output2)
+
+        return output1 + output2
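
For context, a minimal smoke test of the new FedDGCNv2 variant might look like the sketch below. It is not part of the patch: it assumes a FederatedScope checkout so that federatedscope.trafficflow.model.DGCRUCell resolves, and it fakes the config with a SimpleNamespace carrying the attribute names the model reads (num_nodes, input_dim, rnn_units, output_dim, horizon, num_layers, use_day, use_week, dropout, embed_dim, cheb_order); in a real run the framework's model_config is passed via get_model with use_minigraph enabled. The channel layout follows the forward code: channel 0 is the traffic value, channel 1 a time-of-day fraction indexing the 288-slot embedding, channel 2 a day-of-week index in {0, ..., 6}.

# Hypothetical smoke test for FedDGCNv2 (toy config, not the framework's model_config).
import types

import torch

from federatedscope.trafficflow.model.FedDGCNv2 import FedDGCN

args = types.SimpleNamespace(
    num_nodes=16, input_dim=1, rnn_units=32, output_dim=1, horizon=12,
    num_layers=1, use_day=True, use_week=True, dropout=0.1, embed_dim=8,
    cheb_order=2)
model = FedDGCN(args)

# source: (batch, seq_len, num_nodes, 3); seq_len must equal horizon so the
# residual source - source1 broadcasts.
batch, seq_len = 4, args.horizon
source = torch.zeros(batch, seq_len, args.num_nodes, 3)
source[..., 0] = torch.randn(batch, seq_len, args.num_nodes)          # traffic value
source[..., 1] = torch.rand(batch, seq_len, args.num_nodes)           # time-of-day in [0, 1)
source[..., 2] = torch.randint(0, 7, (batch, seq_len, args.num_nodes)).float()  # day-of-week

pred = model(source)
# end_conv* are Conv2d(1, horizon * output_dim, kernel_size=(1, rnn_units)),
# so the prediction is (batch, horizon * output_dim, num_nodes, 1).
print(pred.shape)  # torch.Size([4, 12, 16, 1])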