# Source path: TrafficWheel/model/AGCRN/AGCRNCell.py
# Listing metadata: 26 lines, 1.0 KiB, Python

import torch
import torch.nn as nn
from model.AGCRN.AGCN import AVWGCN
class AGCRNCell(nn.Module):
    """GRU-style recurrent cell whose gates and candidate come from AVWGCN layers.

    One ``AVWGCN`` (``self.gate``) produces both gates in a single call and a
    second one (``self.update``) produces the candidate state. Both consume
    the input features concatenated with a hidden state along the last axis.

    Args:
        node_num: Number of nodes; used as the second dimension of the initial
            hidden state.
        dim_in: Size of the last axis of ``x``.
        dim_out: Size of the last axis of the hidden state.
        cheb_k: Forwarded unchanged to ``AVWGCN``.
        embed_dim: Forwarded unchanged to ``AVWGCN``.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim):
        super().__init__()
        self.node_num = node_num
        self.hidden_dim = dim_out
        # The gate layer emits 2 * dim_out channels so that a single call
        # yields both gates; they are split apart again in ``forward``.
        self.gate = AVWGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim)
        self.update = AVWGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim)

    def forward(self, x, state, node_embeddings):
        """Advance the cell by one step and return the new hidden state.

        Args:
            x: Input tensor, laid out as (B, num_nodes, input_dim).
            state: Previous hidden state, laid out as
                (B, num_nodes, hidden_dim).
            node_embeddings: Passed through unchanged to both ``AVWGCN``
                layers.

        Returns:
            The new hidden state, with the same shape as ``state`` and the
            device and dtype of ``x``.
        """
        # x: B, num_nodes, input_dim
        # state: B, num_nodes, hidden_dim
        # ``init_hidden_state`` allocates on CPU in the default dtype unless
        # told otherwise, so align the state with the input here. Matching the
        # dtype as well as the device keeps ``torch.cat`` from silently
        # promoting to a dtype that differs from the input's.
        state = state.to(device=x.device, dtype=x.dtype)

        input_and_state = torch.cat((x, state), dim=-1)
        z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings))
        # First hidden_dim channels -> z, remaining hidden_dim channels -> r.
        z, r = torch.split(z_r, self.hidden_dim, dim=-1)

        # z scales the previous state before it enters the candidate
        # computation; r then blends the previous state with the candidate.
        candidate = torch.cat((x, z * state), dim=-1)
        hc = torch.tanh(self.update(candidate, node_embeddings))
        h = r * state + (1 - r) * hc
        return h

    def init_hidden_state(self, batch_size, device=None, dtype=None):
        """Return an all-zero hidden state of shape (batch_size, node_num, hidden_dim).

        Args:
            batch_size: Size of the leading (batch) dimension.
            device: Optional device to allocate on. ``None`` keeps the
                original behaviour (torch's default device).
            dtype: Optional dtype. ``None`` keeps the original behaviour
                (torch's default dtype).
        """
        return torch.zeros(
            batch_size, self.node_num, self.hidden_dim, device=device, dtype=dtype
        )