From d016dd598092f0b61125b56e8fb0d3d7c5c3f674 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 27 Mar 2025 20:07:26 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=B6=88=E8=80=97=E5=88=86?= =?UTF-8?q?=E6=9E=90=E6=A8=A1=E5=BC=8F=EF=BC=8C=E5=8F=AA=E9=9C=80=E5=9C=A8?= =?UTF-8?q?=E5=8E=9F=E6=9C=89=E7=9A=84mode=E4=B8=AD=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E4=B8=BAbenchmark=E5=8D=B3=E5=8F=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/DDGCRN/DDGCRN.py | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/model/DDGCRN/DDGCRN.py b/model/DDGCRN/DDGCRN.py index 237a20e..cdf1cc9 100644 --- a/model/DDGCRN/DDGCRN.py +++ b/model/DDGCRN/DDGCRN.py @@ -7,14 +7,13 @@ class DGCRM(nn.Module): super().__init__() self.node_num, self.input_dim, self.num_layers = node_num, dim_in, num_layers self.cells = nn.ModuleList( - [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in - range(num_layers)]) + [DDGCRNCell(node_num, dim_in if i == 0 else dim_out, dim_out, cheb_k, embed_dim) for i in range(num_layers)] + ) def forward(self, x, init_state, node_embeddings): assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim for i in range(self.num_layers): - state, inner = init_state[i], [] - state = state.to(x.device) + state, inner = init_state[i].to(x.device), [] for t in range(x.shape[1]): state = self.cells[i](x[:, t, :, :], state, [node_embeddings[0][:, t, :, :], node_embeddings[1]]) inner.append(state) @@ -69,8 +68,8 @@ class DDGCRNCell(nn.Module): def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): super().__init__() self.node_num, self.hidden_dim = node_num, dim_out - self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim) - self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim) + self.gate = DGCN(dim_in + dim_out, 2 * dim_out, cheb_k, embed_dim, node_num) + self.update = DGCN(dim_in + dim_out, dim_out, cheb_k, embed_dim, node_num) def forward(self, x, state, node_embeddings): inp = torch.cat((x, state), -1) @@ -84,7 +83,7 @@ class DDGCRNCell(nn.Module): class DGCN(nn.Module): - def __init__(self, dim_in, dim_out, cheb_k, embed_dim): + def __init__(self, dim_in, dim_out, cheb_k, embed_dim, num_nodes): super().__init__() self.cheb_k, self.embed_dim = cheb_k, embed_dim self.weights_pool = nn.Parameter(torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out)) @@ -98,14 +97,16 @@ class DGCN(nn.Module): ('sigmoid2', nn.Sigmoid()), ('fc3', nn.Linear(2, embed_dim)) ])) + # 预注册恒定不变的单位矩阵 + self.register_buffer('eye', torch.eye(num_nodes)) def forward(self, x, node_embeddings): - node_num = node_embeddings[0].shape[1] - supp1 = torch.eye(node_num).to(node_embeddings[0].device) + supp1 = self.eye.to(node_embeddings[0].device) filt = self.fc(x) nodevec = torch.tanh(node_embeddings[0] * filt) supp2 = self.get_laplacian(F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supp1) - x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1) + x_g = torch.stack([torch.einsum("nm,bmc->bnc", supp1, x), + torch.einsum("bnm,bmc->bnc", supp2, x)], dim=1) weights = torch.einsum('nd,dkio->nkio', node_embeddings[1], self.weights_pool) bias = torch.matmul(node_embeddings[1], self.bias_pool) return torch.einsum('bnki,nkio->bno', x_g.permute(0, 2, 1, 3), weights) + bias