train success

HengZhang 2024-11-28 11:15:47 +08:00
parent 9cfe3f01dc
commit e95c13f0fc
1 changed file with 13 additions and 19 deletions


@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell

 class DGCRM(nn.Module):
     def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
         super(DGCRM, self).__init__()
@@ -24,7 +25,8 @@ class DGCRM(nn.Module):
             state = init_state[i]
             inner_states = []
             for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
+                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                 inner_states.append(state)
             output_hidden.append(state)
             current_inputs = torch.stack(inner_states, dim=1)
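The loop in this hunk is the core of the DGCRM recurrence: for each layer, the cell is stepped over the time axis and the per-step states are stacked back into a (B, T, N, hidden) tensor that becomes the next layer's input. Below is a minimal, self-contained sketch of that pattern with a toy cell standing in for DGCRUCell; the names and shapes are illustrative assumptions, not the repository's exact API.

import torch
import torch.nn as nn

class _ToyCell(nn.Module):
    """Stand-in for DGCRUCell: mixes the current input slice into the running state."""
    def __init__(self, dim_in, dim_hidden):
        super().__init__()
        self.proj = nn.Linear(dim_in + dim_hidden, dim_hidden)

    def forward(self, x_t, state, node_embeddings):
        # x_t: (B, N, dim_in), state: (B, N, dim_hidden); embeddings are ignored in this toy
        return torch.tanh(self.proj(torch.cat([x_t, state], dim=-1)))

B, T, N, D, H = 4, 12, 10, 1, 16          # assumed sizes for the demo
cell = _ToyCell(D, H)
x = torch.randn(B, T, N, D)               # (batch, time, nodes, features)
state = torch.zeros(B, N, H)              # initial hidden state for one layer
inner_states = []
for t in range(T):
    state = cell(x[:, t, :, :], state, None)   # one step, mirroring the loop above
    inner_states.append(state)
next_layer_input = torch.stack(inner_states, dim=1)  # (B, T, N, H) feeds the next DGCRM layer
print(next_layer_input.shape)  # torch.Size([4, 12, 10, 16])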
@@ -34,7 +36,8 @@ class DGCRM(nn.Module):
         init_states = []
         for i in range(self.num_layers):
             init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
-        return torch.stack(init_states, dim=0)  #(num_layers, B, N, hidden_dim)
+        return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)

 # Build you torch or tf model class here
 class FedDGCN(nn.Module):
@@ -49,7 +52,7 @@ class FedDGCN(nn.Module):
         self.num_layers = args.num_layers
         self.use_D = args.use_day
         self.use_W = args.use_week
         self.dropout1 = nn.Dropout(p=args.dropout)  # 0.1
         self.dropout2 = nn.Dropout(p=args.dropout)
         self.node_embeddings1 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
         self.node_embeddings2 = nn.Parameter(torch.randn(self.num_node, args.embed_dim), requires_grad=True)
@@ -72,16 +75,16 @@ class FedDGCN(nn.Module):
     def forward(self, source):
         node_embedding1 = self.node_embeddings1
         if self.use_D:
             t_i_d_data = source[..., 1]
             T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, T_i_D_emb)
         if self.use_W:
             d_i_w_data = source[..., 2]
             D_i_W_emb = self.D_i_W_emb[(d_i_w_data).type(torch.LongTensor)]
             node_embedding1 = torch.mul(node_embedding1, D_i_W_emb)
-        node_embeddings=[node_embedding1,self.node_embeddings1]
+        node_embeddings = [node_embedding1, self.node_embeddings1]
         source = source[..., 0].unsqueeze(-1)
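The use_D / use_W branches read the extra input channels as calendar features: channel 1 appears to be a time-of-day value normalized to [0, 1) (288 five-minute slots per day, hence the * 288), and channel 2 an integer day-of-week. A small sketch of that lookup follows, with a hypothetical tensor standing in for self.T_i_D_emb; day-of-week works the same way with a (7, embed_dim) table and weekday indices 0..6.

import torch

# 06:00 -> 360 / 1440 = 0.25 of the day; 0.25 * 288 = slot 72 of 288 five-minute slots.
embed_dim = 8
T_i_D_emb = torch.randn(288, embed_dim)             # illustrative table, like self.T_i_D_emb
t_i_d_data = torch.tensor([[0.0, 0.25, 0.999]])     # normalized times of day, shape (B, ...)
idx = (t_i_d_data * 288).long()                     # tensor([[  0,  72, 287]])
slot_emb = T_i_D_emb[idx]                           # gathers per-entry embeddings, (..., embed_dim)
print(idx, slot_emb.shape)                          # torch.Size([1, 3, 8])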
@@ -107,15 +110,11 @@ class FederatedFedDGCN(nn.Module):
         super(FederatedFedDGCN, self).__init__()
         # Initializing with None, we will populate model_list during the forward pass
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model_list = None
-        self.main_model = FedDGCN(args)  # Initialize a single FedDGCN model (for aggregation)
+        self.graph_num = (args.num_nodes + args.minigraph_size - 1) // args.minigraph_size
         self.args = args
-        self.subgraph_num = 0
-        self.num_node = args.minigraph_size
-        self.input_dim = args.input_dim
-        self.hidden_dim = args.rnn_units
-        self.output_dim = args.output_dim
-        self.horizon = args.horizon
+        self.model_list = ModuleList(FedDGCN(self.args).to(self.device) for _ in range(self.graph_num))

     def forward(self, source):
         """
@@ -130,11 +129,6 @@ class FederatedFedDGCN(nn.Module):
         """
         self.subgraph_num = source.shape[2]
-        # Initialize model_list if it hasn't been initialized yet
-        if self.model_list is None:
-            # Initialize model_list with FedDGCN models, one for each subgraph
-            self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
         # Initialize a list to store the outputs of each subgraph model
         subgraph_outputs = []
@@ -150,7 +144,7 @@ class FederatedFedDGCN(nn.Module):
         # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
         output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)
-        self.update_main_model()
+        # self.update_main_model()
         return output_tensor
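After the loop over sub-models, the per-subgraph predictions are stacked back along a new subgraph axis, and the commented-out self.update_main_model() call suggests that aggregation into a single main model is no longer done inside forward. A minimal shape check for the stacking step; all sizes below are illustrative assumptions that only follow the comments in the diff.

import torch

batch, horizon, subgraph_num, subgraph_size, dims = 2, 12, 10, 32, 1
subgraph_outputs = [torch.randn(batch, horizon, subgraph_size, dims)
                    for _ in range(subgraph_num)]     # one output per sub-model
output_tensor = torch.stack(subgraph_outputs, dim=2)  # insert the subgraph axis at dim 2
print(output_tensor.shape)  # torch.Size([2, 12, 10, 32, 1])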