rewrite function

This commit is contained in:
HengZhang 2024-11-27 16:17:57 +08:00
parent b335047b87
commit 84fddbb8cd
2 changed files with 5 additions and 5 deletions

View File

@@ -3,9 +3,9 @@ import torch.nn as nn
from federatedscope.trafficflow.model.DGCN import DGCN from federatedscope.trafficflow.model.DGCN import DGCN
class DDGCRNCell(nn.Module): class DGCRUCell(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
super(DDGCRNCell, self).__init__() super(DGCRUCell, self).__init__()
self.node_num = node_num self.node_num = node_num
self.hidden_dim = dim_out self.hidden_dim = dim_out
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)

View File

@@ -1,7 +1,7 @@
from federatedscope.register import register_model from federatedscope.register import register_model
import torch import torch
import torch.nn as nn import torch.nn as nn
from federatedscope.trafficflow.model.DGCRUCell import DDGCRNCell from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
class DGCRM(nn.Module): class DGCRM(nn.Module):
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
@@ -11,9 +11,9 @@ class DGCRM(nn.Module):
self.input_dim = dim_in self.input_dim = dim_in
self.num_layers = num_layers self.num_layers = num_layers
self.DGCRM_cells = nn.ModuleList() self.DGCRM_cells = nn.ModuleList()
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)) self.DGCRM_cells.append(DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
for _ in range(1, num_layers): for _ in range(1, num_layers):
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)) self.DGCRM_cells.append(DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
def forward(self, x, init_state, node_embeddings): def forward(self, x, init_state, node_embeddings):
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim