rewrite function
This commit is contained in:
parent
b335047b87
commit
84fddbb8cd
|
|
@ -3,9 +3,9 @@ import torch.nn as nn
|
||||||
from federatedscope.trafficflow.model.DGCN import DGCN
|
from federatedscope.trafficflow.model.DGCN import DGCN
|
||||||
|
|
||||||
|
|
||||||
class DDGCRNCell(nn.Module):
|
class DGCRUCell(nn.Module):
|
||||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
|
||||||
super(DDGCRNCell, self).__init__()
|
super(DGCRUCell, self).__init__()
|
||||||
self.node_num = node_num
|
self.node_num = node_num
|
||||||
self.hidden_dim = dim_out
|
self.hidden_dim = dim_out
|
||||||
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from federatedscope.register import register_model
|
from federatedscope.register import register_model
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from federatedscope.trafficflow.model.DGCRUCell import DDGCRNCell
|
from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell
|
||||||
|
|
||||||
class DGCRM(nn.Module):
|
class DGCRM(nn.Module):
|
||||||
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
|
||||||
|
|
@ -11,9 +11,9 @@ class DGCRM(nn.Module):
|
||||||
self.input_dim = dim_in
|
self.input_dim = dim_in
|
||||||
self.num_layers = num_layers
|
self.num_layers = num_layers
|
||||||
self.DGCRM_cells = nn.ModuleList()
|
self.DGCRM_cells = nn.ModuleList()
|
||||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
|
self.DGCRM_cells.append(DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim))
|
||||||
for _ in range(1, num_layers):
|
for _ in range(1, num_layers):
|
||||||
self.DGCRM_cells.append(DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
|
self.DGCRM_cells.append(DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim))
|
||||||
|
|
||||||
def forward(self, x, init_state, node_embeddings):
|
def forward(self, x, init_state, node_embeddings):
|
||||||
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue