rewrite function
This commit is contained in:
parent
b335047b87
commit
84fddbb8cd
|
|
@ -3,9 +3,9 @@ import torch.nn as nn
|
|||
from federatedscope.trafficflow.model.DGCN import DGCN
|
||||
|
||||
|
||||
class DDGCRNCell(nn.Module):
|
||||
class DGCRUCell(nn.Module):
|
||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
||||
super(DDGCRNCell, self).__init__()
|
||||
super(DGCRUCell, self).__init__()
|
||||
self.node_num = node_num
|
||||
self.hidden_dim = dim_out
|
||||
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from federatedscope.register import register_model
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from federatedscope.trafficflow.model.DGCRUCell import DDGCRNCell
|
||||
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
|
||||
|
||||
class DGCRM(nn.Module):
|
||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
||||
|
|
@ -11,9 +11,9 @@ class DGCRM(nn.Module):
|
|||
self.input_dim = dim_in
|
||||
self.num_layers = num_layers
|
||||
self.DGCRM_cells = nn.ModuleList()
|
||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
|
||||
self.DGCRM_cells.append(DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
|
||||
for _ in range(1, num_layers):
|
||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
|
||||
self.DGCRM_cells.append(DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
|
||||
|
||||
def forward(self, x, init_state, node_embeddings):
|
||||
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||
|
|
|
|||
Loading…
Reference in New Issue