30 lines
1.3 KiB
Python
30 lines
1.3 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
from model.TWDGCN.DGCN import DGCN
|
||
|
||
|
||
class DDGCRNCell(nn.Module):
    """Single-time-step GRU-style recurrent cell built on DGCN graph convolutions.

    This module only performs the GRU-internal state update; the graph
    convolution itself is delegated to ``DGCN``, so changes to the
    convolution belong there rather than here.

    Args:
        node_num: Number of graph nodes; used to size the initial hidden state.
        dim_in: Feature dimension of the per-step input ``x``.
        dim_out: Hidden-state feature dimension.
        cheb_k: Forwarded unchanged to ``DGCN``.
        embed_dim: Forwarded unchanged to ``DGCN``.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # A single DGCN emits both gates (2 * dim_out channels); they are
        # split apart in forward().
        self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings, connMTX):
        """Advance the hidden state by one time step.

        Args:
            x: Input for a single time step, shape (B, num_nodes, input_dim).
            state: Previous hidden state, shape (B, num_nodes, hidden_dim).
            node_embeddings: Passed through unchanged to both DGCN layers.
            connMTX: Passed to both DGCN layers after being moved to the
                device of ``x`` (presumably a connectivity/adjacency
                matrix -- confirm against ``DGCN``).

        Returns:
            The new hidden state, same shape as ``state``.
        """
        # The hidden state may have been created on another device
        # (see init_hidden_state), so align it with the input first.
        state = state.to(x.device)
        # Move connMTX once and reuse it. Previously the update branch called
        # ``connMTX.to(input_and_state)``, which also casts the dtype, so the
        # gate and update convolutions could see different tensors.
        connMTX = connMTX.to(x.device)

        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings, connMTX))
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)

        # z scales the previous state that feeds the candidate; r blends the
        # previous state with the candidate to form the new state.
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings, connMTX))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size):
        """Return an all-zero hidden state of shape (batch_size, node_num, hidden_dim).

        The tensor is allocated with torch's default device and dtype;
        forward() moves it to the input's device.
        """
        return torch.zeros(batch_size, self.node_num, self.hidden_dim)
|
||
|