# Source-viewer metadata (not part of the program): 100 lines, 3.6 KiB, Python, executable file.
import torch
|
|
import torch.nn as nn
|
|
from model.AGCRN.AGCRNCell import AGCRNCell
|
|
|
|
|
|
class AVWDCRNN(nn.Module):
    """Multi-layer recurrent encoder built from ``AGCRNCell`` layers.

    The first cell is constructed with ``dim_in`` input features and every
    following cell with ``dim_out`` input features; all cells are constructed
    with ``dim_out`` as their output size. Each layer is unrolled over the
    time axis, and the per-step states of one layer become the input sequence
    of the next.

    Args:
        node_num: Expected size of axis 2 of the input passed to ``forward``.
        dim_in: Expected size of axis 3 of the input passed to ``forward``.
        dim_out: Output size handed to every ``AGCRNCell``.
        cheb_k: Forwarded unchanged to every ``AGCRNCell``.
        embed_dim: Forwarded unchanged to every ``AGCRNCell``.
        num_layers: Number of stacked cells; must be at least 1.
    """

    def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1):
        super(AVWDCRNN, self).__init__()
        assert num_layers >= 1, "At least one DCRNN layer in the Encoder."
        self.node_num = node_num
        self.input_dim = dim_in
        self.num_layers = num_layers
        # Only the first layer sees the raw features; deeper layers consume
        # the previous layer's states.
        layer_input_dims = [dim_in] + [dim_out] * (num_layers - 1)
        self.dcrnn_cells = nn.ModuleList(
            [
                AGCRNCell(node_num, layer_dim, dim_out, cheb_k, embed_dim)
                for layer_dim in layer_input_dims
            ]
        )

    def forward(self, x, init_state, node_embeddings):
        """Run every layer over the whole input sequence.

        Args:
            x: Input sequence of shape (B, T, N, D), where N must equal
                ``node_num`` and D must equal ``dim_in``.
            init_state: Initial state per layer, indexed along axis 0:
                (num_layers, B, N, hidden_dim).
            node_embeddings: Passed unchanged to every cell call.

        Returns:
            A tuple ``(outputs, last_states)``: ``outputs`` stacks the last
            layer's per-step states along axis 1, (B, T, N, hidden_dim);
            ``last_states`` is a Python list holding the final state of each
            layer, each (B, N, hidden_dim).

        Raises:
            AssertionError: If axes 2 and 3 of ``x`` do not match
                ``node_num`` and ``dim_in``.
        """
        assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim, (
            f"expected input of shape (B, T, {self.node_num}, {self.input_dim}), "
            f"got {tuple(x.shape)}"
        )
        seq_length = x.shape[1]
        current_inputs = x
        output_hidden = []
        for layer_idx, cell in enumerate(self.dcrnn_cells):
            state = init_state[layer_idx]
            inner_states = []
            for t in range(seq_length):
                state = cell(current_inputs[:, t, :, :], state, node_embeddings)
                inner_states.append(state)
            # Keep the final state of this layer and feed its full state
            # sequence to the next layer.
            output_hidden.append(state)
            current_inputs = torch.stack(inner_states, dim=1)
        return current_inputs, output_hidden

    def init_hidden(self, batch_size):
        """Stack each cell's ``init_hidden_state(batch_size)`` along axis 0.

        Returns:
            Tensor of shape (num_layers, B, N, hidden_dim).
        """
        return torch.stack(
            [cell.init_hidden_state(batch_size) for cell in self.dcrnn_cells],
            dim=0,
        )
|
|
|
|
|
|
class AGCRN(nn.Module):
    """Sequence-to-horizon forecaster: ``AVWDCRNN`` encoder plus a Conv2d head.

    Args:
        args: Mapping that must provide the keys ``"num_nodes"``,
            ``"input_dim"``, ``"rnn_units"``, ``"output_dim"``, ``"horizon"``,
            ``"num_layers"``, ``"default_graph"``, ``"embed_dim"`` and
            ``"cheb_k"``.
    """

    def __init__(self, args):
        super(AGCRN, self).__init__()
        self.num_node = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.hidden_dim = args["rnn_units"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.num_layers = args["num_layers"]

        # Stored only; nothing in this class reads it afterwards.
        self.default_graph = args["default_graph"]
        # Learnable per-node embedding table of shape (num_nodes, embed_dim),
        # handed to the encoder on every forward pass.
        self.node_embeddings = nn.Parameter(
            torch.randn(self.num_node, args["embed_dim"]), requires_grad=True
        )

        self.encoder = AVWDCRNN(
            args["num_nodes"],
            args["input_dim"],
            args["rnn_units"],
            args["cheb_k"],
            args["embed_dim"],
            args["num_layers"],
        )

        # Predictor: a (1, hidden_dim) kernel collapses the hidden axis and
        # emits horizon * output_dim channels per node.
        self.end_conv = nn.Conv2d(
            1,
            args["horizon"] * self.output_dim,
            kernel_size=(1, self.hidden_dim),
            bias=True,
        )

    def forward(self, source):
        """Predict ``horizon`` future steps from an input window.

        Args:
            source: Tensor of shape (B, T_in, N, D). Only the first
                ``input_dim`` entries of the last axis are fed to the
                encoder; any trailing entries are ignored.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).
        """
        init_state = self.encoder.init_hidden(source.shape[0])
        # Slice to the configured feature count. A hard-coded ``:1`` would
        # trip the encoder's shape assertion whenever input_dim != 1.
        output, _ = self.encoder(
            source[..., : self.input_dim], init_state, self.node_embeddings
        )  # B, T, N, hidden
        output = output[:, -1:, :, :]  # B, 1, N, hidden (last time step only)

        # CNN based predictor
        output = self.end_conv(output)  # B, horizon * output_dim, N, 1
        output = output.squeeze(-1).reshape(
            -1, self.horizon, self.output_dim, self.num_node
        )
        return output.permute(0, 1, 3, 2)  # B, horizon, N, output_dim
|