From 84fddbb8cd909b2880f58de4ffd89ee187d2446d Mon Sep 17 00:00:00 2001 From: HengZhang Date: Wed, 27 Nov 2024 16:17:57 +0800 Subject: [PATCH] rewrite function --- federatedscope/trafficflow/model/DGCRUCell.py | 4 ++-- federatedscope/trafficflow/model/FedDGCN.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/federatedscope/trafficflow/model/DGCRUCell.py b/federatedscope/trafficflow/model/DGCRUCell.py index 299e977..b15f4f9 100644 --- a/federatedscope/trafficflow/model/DGCRUCell.py +++ b/federatedscope/trafficflow/model/DGCRUCell.py @@ -3,9 +3,9 @@ import torch.nn as nn from federatedscope.trafficflow.model.DGCN import DGCN -class DDGCRNCell(nn.Module): +class DGCRUCell(nn.Module): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): - super(DDGCRNCell, self).__init__() + super(DGCRUCell, self).__init__() self.node_num = node_num self.hidden_dim = dim_out self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) diff --git a/federatedscope/trafficflow/model/FedDGCN.py b/federatedscope/trafficflow/model/FedDGCN.py index 763db82..91de12a 100644 --- a/federatedscope/trafficflow/model/FedDGCN.py +++ b/federatedscope/trafficflow/model/FedDGCN.py @@ -1,7 +1,7 @@ from federatedscope.register import register_model import torch import torch.nn as nn -from federatedscope.trafficflow.model.DGCRUCell import DDGCRNCell +from federatedscope.trafficflow.model.DGCRUCell import DGCRUCell class DGCRM(nn.Module): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): @@ -11,9 +11,9 @@ class DGCRM(nn.Module): self.input_dim = dim_in self.num_layers = num_layers self.DGCRM_cells = nn.ModuleList() - self.DGCRM_cells.append(DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim)) + self.DGCRM_cells.append(DGCRUCell(node_num, dim_in, dim_out, cheb_k, embed_dim)) for _ in range(1, num_layers): - self.DGCRM_cells.append(DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim)) + self.DGCRM_cells.append(DGCRUCell(node_num, dim_out, dim_out, cheb_k, embed_dim)) def forward(self, x, init_state, node_embeddings): assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim