train success

parent 9cfe3f01dc
commit e95c13f0fc
@@ -3,6 +3,7 @@ import torch
 import torch.nn as nn
 from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
 
+
 class DGCRM(nn.Module):
     def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
         super(DGCRM, self).__init__()
@@ -24,7 +25,8 @@ class DGCRM(nn.Module):
             state = init_state[i]
             inner_states = []
             for t in range(seq_length):
-                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]])
+                state = self.DGCRM_cells[i](current_inputs[:, t, :, :], state,
+                                            [node_embeddings[0][:, t, :, :], node_embeddings[1]])
                 inner_states.append(state)
             output_hidden.append(state)
             current_inputs = torch.stack(inner_states, dim=1)
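Note: the only change in this hunk is cosmetic (the cell call is wrapped across two lines). For context, the loop unrolls the recurrent cell over seq_length steps and hands the stacked per-step states to the next layer. A stripped-down sketch of that pattern, with nn.GRUCell standing in for DGCRUCell (which is not part of this diff):

    import torch
    import torch.nn as nn

    batch, seq_length, dim = 8, 12, 16
    cell = nn.GRUCell(dim, dim)              # stand-in for DGCRUCell
    current_inputs = torch.randn(batch, seq_length, dim)
    state = torch.zeros(batch, dim)

    inner_states = []
    for t in range(seq_length):
        state = cell(current_inputs[:, t, :], state)  # one recurrent step
        inner_states.append(state)
    # the next layer consumes all per-step states, stacked along the time axis
    current_inputs = torch.stack(inner_states, dim=1)  # (batch, seq_length, dim)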
@@ -36,6 +38,7 @@ class DGCRM(nn.Module):
             init_states.append(self.DGCRM_cells[i].init_hidden_state(batch_size))
         return torch.stack(init_states, dim=0)  # (num_layers, B, N, hidden_dim)
 
+
 # Build you torch or tf model class here
 class FedDGCN(nn.Module):
     def __init__(self, args):
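Note: the return-shape comment above can be checked mechanically. Assuming per-layer hidden states of shape (B, N, hidden_dim), stacking along a new dim 0 yields (num_layers, B, N, hidden_dim); the sizes below are hypothetical:

    import torch

    num_layers, batch_size, node_num, hidden_dim = 2, 8, 50, 64  # hypothetical sizes
    init_states = [torch.zeros(batch_size, node_num, hidden_dim) for _ in range(num_layers)]
    stacked = torch.stack(init_states, dim=0)
    print(stacked.shape)  # torch.Size([2, 8, 50, 64]) == (num_layers, B, N, hidden_dim)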
@@ -107,15 +110,11 @@ class FederatedFedDGCN(nn.Module):
         super(FederatedFedDGCN, self).__init__()
 
         # Initializing with None, we will populate model_list during the forward pass
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.model_list = None
-        self.main_model = FedDGCN(args)  # Initialize a single FedDGCN model (for aggregation)
+        self.graph_num = (args.num_nodes + args.minigraph_size - 1) // args.minigraph_size
         self.args = args
-        self.subgraph_num = 0
-        self.num_node = args.minigraph_size
-        self.input_dim = args.input_dim
-        self.hidden_dim = args.rnn_units
-        self.output_dim = args.output_dim
-        self.horizon = args.horizon
+        self.model_list = ModuleList(FedDGCN(self.args).to(self.device) for _ in range(self.graph_num))
 
     def forward(self, source):
         """
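Note: this is the substantive change of the commit. Instead of creating submodels lazily on the first forward pass, the constructor derives the number of subgraph models with integer ceiling division and builds them eagerly inside a ModuleList, which (unlike a plain Python list) registers their parameters with the parent module so optimizers and .to(device) can see them. A minimal sketch of both points; the num_nodes/minigraph_size values are hypothetical:

    import math
    import torch
    import torch.nn as nn

    num_nodes, minigraph_size = 307, 50  # hypothetical values
    graph_num = (num_nodes + minigraph_size - 1) // minigraph_size
    assert graph_num == math.ceil(num_nodes / minigraph_size)  # 7: six full subgraphs + one partial

    class Wrapper(nn.Module):
        def __init__(self):
            super().__init__()
            # nn.ModuleList registers each submodule, unlike a plain list
            self.model_list = nn.ModuleList(nn.Linear(4, 4) for _ in range(graph_num))

    print(sum(p.numel() for p in Wrapper().parameters()))  # 140: submodule params are visible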
@@ -130,11 +129,6 @@ class FederatedFedDGCN(nn.Module):
         """
         self.subgraph_num = source.shape[2]
 
-        # Initialize model_list if it hasn't been initialized yet
-        if self.model_list is None:
-            # Initialize model_list with FedDGCN models, one for each subgraph
-            self.model_list = ModuleList([self.main_model] + [FedDGCN(self.args) for _ in range(self.subgraph_num - 1)])
-
         # Initialize a list to store the outputs of each subgraph model
         subgraph_outputs = []
 
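Note: with the models now built in __init__, the lazy-initialization branch is dead code and is removed; self.subgraph_num survives only as bookkeeping read off the input tensor. Judging from the shape comments elsewhere in this diff, the input appears to be laid out as (batch, time, subgraph_num, subgraph_size, dims), so shape[2] is the subgraph count; a hypothetical check under that assumption:

    import torch

    # hypothetical input layout: (batch, time, subgraph_num, subgraph_size, dims)
    source = torch.randn(8, 12, 7, 50, 1)
    subgraph_num = source.shape[2]
    assert subgraph_num == 7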
@@ -150,7 +144,7 @@ class FederatedFedDGCN(nn.Module):
         # Reshape the outputs into (batchsize, horizon, subgraph_num, subgraph_size, dims)
         output_tensor = torch.stack(subgraph_outputs, dim=2)  # (batchsize, horizon, subgraph_num, subgraph_size, dims)
 
-        self.update_main_model()
+        # self.update_main_model()
 
         return output_tensor
 
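Note: the call to self.update_main_model() is commented out rather than deleted, so aggregation of the per-subgraph models no longer happens inside forward. The stack call above inserts the subgraph axis between the horizon and node dimensions, matching the comment; a sketch with hypothetical sizes, assuming each submodel emits (batch, horizon, subgraph_size, dims):

    import torch

    subgraph_outputs = [torch.randn(8, 12, 50, 1) for _ in range(7)]  # one per subgraph model
    output_tensor = torch.stack(subgraph_outputs, dim=2)
    print(output_tensor.shape)  # torch.Size([8, 12, 7, 50, 1])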