TrafficWheel/model/TWDGCN/DGCRU.py

37 lines
1.3 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
from model.TWDGCN.DGCN import DGCN
class DDGCRNCell(nn.Module):
    """One-step GRU-style recurrent cell whose transforms are ``DGCN`` layers.

    This module only performs the GRU-internal update for a single time step;
    any change to how graph information is aggregated belongs in ``DGCN``.

    Args:
        node_num: Number of graph nodes (node dimension of the hidden state).
        dim_in: Feature size of the per-step input ``x``.
        dim_out: Hidden-state feature size.
        cheb_k: Forwarded unchanged to ``DGCN`` (presumably a Chebyshev/hop
            order -- confirm against ``DGCN``).
        embed_dim: Forwarded unchanged to ``DGCN`` (node-embedding size).
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # A single DGCN emits both gates at once (2 * dim_out channels).
        self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings, connMTX):
        """Advance the hidden state by one time step.

        Args:
            x: Input for a single time step, shaped (B, num_nodes, input_dim).
            state: Previous hidden state, shaped (B, num_nodes, hidden_dim);
                moved to ``x``'s device if needed.
            node_embeddings: Passed through unchanged to both ``DGCN`` calls.
            connMTX: Connectivity matrix passed to both ``DGCN`` calls; it is
                moved to ``x``'s device but its dtype is left unchanged.

        Returns:
            The new hidden state, same shape as ``state``.
        """
        device = x.device
        state = state.to(device)
        # Move the connectivity matrix once and hand the *same* tensor to both
        # DGCN calls, so the gate and candidate paths always see an identical
        # device/dtype (only the device is changed, never the dtype).
        conn = connMTX.to(device)

        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings, conn))
        # Note on naming: ``z`` scales the previous state inside the candidate
        # (a reset-style role) and ``r`` interpolates between the previous state
        # and the candidate (an update-style role).
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)

        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings, conn))
        return r * state + (1 - r) * hc

    def init_hidden_state(self, batch_size, device=None, dtype=None):
        """Return an all-zero hidden state of shape (batch_size, node_num, hidden_dim).

        ``device`` and ``dtype`` are optional; when omitted, torch's defaults are
        used (CPU and the default floating dtype), matching the original behavior.
        """
        return torch.zeros(
            batch_size, self.node_num, self.hidden_dim, device=device, dtype=dtype
        )