"""Hierarchical attention LSTM: stacked LSTM cells with self-attention pooling between
layers and multi-hop self-attention over the top layer."""

import math

import numpy as np
import torch
import torch.nn as nn


class HierAttnLstm(nn.Module):
    def __init__(self, args):
        super().__init__()
        # Data dimensions.
        self.num_nodes = args["num_nodes"]
        self.feature_dim = args["feature_dim"]
        self.output_dim = args["output_dim"]
        # Lengths of the input history and of the prediction horizon.
        self.input_window = args["input_window"]
        self.output_window = args["output_window"]
        # LSTM stack configuration.
        self.hidden_size = args["hidden_size"]
        self.num_layers = args["num_layers"]
        # Attention hops, width of the output head, and cap on the number of pooled
        # segments fed to each upper layer.
        self.natt_hops = args["natt_hops"]
        self.nfc = args["nfc"]
        self.max_up_len = args["max_up_len"]
        self.input_size = self.num_nodes * self.feature_dim
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Bottom cell consumes the flattened node features; upper cells consume pooled
        # hidden states from the layer below.
        self.lstm_cells = nn.ModuleList(
            [nn.LSTMCell(self.input_size, self.hidden_size)]
            + [
                nn.LSTMCell(self.hidden_size, self.hidden_size)
                for _ in range(self.num_layers - 1)
            ]
        )
        # One pooling module per upper layer, for hidden and cell states respectively.
        self.hidden_state_pooling = nn.ModuleList(
            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)]
        )
        self.cell_state_pooling = nn.ModuleList(
            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)]
        )
        # Multi-hop self-attention over the top layer, followed by the output head.
        self.self_attention = SelfAttention(self.hidden_size, self.natt_hops)
        self.fc_layer = nn.Sequential(
            nn.Linear(self.hidden_size * self.natt_hops, self.nfc),
            nn.ReLU(),
            nn.Linear(self.nfc, self.num_nodes * self.output_dim),
        )
    def forward(self, batch):
        # batch: (batch_size, input_window, num_nodes, feature_dim). Keep only the first
        # feature channel and move time to the front: (input_window, batch_size, num_nodes).
        batch_size = batch.shape[0]
        src = batch.permute(1, 0, 2, 3)[..., :1]
        src = src.reshape(self.input_window, batch_size, -1)

        outputs = []
        for i in range(self.output_window):
            # Fresh LSTM state for every prediction step.
            hidden_states = [
                torch.zeros(batch_size, self.hidden_size, device=self.device)
                for _ in range(self.num_layers)
            ]
            cell_states = [
                torch.zeros(batch_size, self.hidden_size, device=self.device)
                for _ in range(self.num_layers)
            ]
            bottom_layer_outputs = []
            cell_states_history = [[] for _ in range(self.num_layers)]

            # Bottom layer: plain LSTM cell unrolled over the input window.
            for t in range(self.input_window):
                hidden_states[0], cell_states[0] = self.lstm_cells[0](
                    src[t], (hidden_states[0], cell_states[0])
                )
                bottom_layer_outputs.append(hidden_states[0])
                cell_states_history[0].append(cell_states[0])

            bottom_layer_outputs = torch.stack(bottom_layer_outputs, 1)
            cell_states_history[0] = torch.stack(cell_states_history[0], 1)

            # Upper layers: pool a segment of the layer below, then advance this layer's
            # LSTM cell once per segment, so the sequence shrinks as it moves up.
            for layer in range(1, self.num_layers):
                layer_inputs = bottom_layer_outputs if layer == 1 else layer_outputs
                layer_outputs, cell_states_history[layer] = [], []
                for start, end in self.calculate_stride(layer_inputs.size(1)):
                    segment = layer_inputs[:, start:end, :]
                    cell_segment = cell_states_history[layer - 1][:, start:end, :]
                    pooled_hidden = self.hidden_state_pooling[layer - 1](segment)
                    pooled_cell = self.cell_state_pooling[layer - 1](
                        torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], 1)
                    )
                    hidden_states[layer], cell_states[layer] = self.lstm_cells[layer](
                        pooled_hidden, (hidden_states[layer], pooled_cell)
                    )
                    layer_outputs.append(hidden_states[layer])
                    cell_states_history[layer].append(cell_states[layer])

                layer_outputs = torch.stack(layer_outputs, 1)
                cell_states_history[layer] = torch.stack(cell_states_history[layer], 1)

            # Multi-hop attention over the top-layer outputs, then project to one step of
            # predictions for all nodes.
            attended_features, _ = self.self_attention(layer_outputs)
            out = self.fc_layer(attended_features.view(batch_size, -1)).view(
                batch_size, self.num_nodes, self.output_dim
            )
            outputs.append(out.clone())
            # Autoregressive rollout: slide the input window forward by feeding the
            # prediction back in.
            if i < self.output_window - 1:
                src = torch.cat((src[1:], out.reshape(batch_size, -1).unsqueeze(0)), 0)

        # (output_window, batch, num_nodes, output_dim) -> (batch, output_window, nodes, dim)
        return torch.stack(outputs).permute(1, 0, 2, 3)
    def calculate_stride(self, seq_len):
        # Split a sequence of length seq_len into (start, end) index pairs: the number of
        # boundaries scales roughly with sqrt(seq_len), capped by max_up_len (plus a small
        # constant), and consecutive boundaries delimit the segments pooled by the layer above.
        idx = np.linspace(
            0, seq_len - 1, num=min(self.max_up_len, math.ceil(math.sqrt(seq_len))) + 3
        ).astype(int)
        return list(zip(np.append(idx, seq_len - 1)[:-1], idx[1:]))
class SelfAttentionPooling(nn.Module):
    """Collapse a (batch, seq_len, input_dim) sequence into one attention-weighted vector."""

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        # Score each position, softmax the scores over the sequence, and return the
        # weighted sum: (batch, seq_len, input_dim) -> (batch, input_dim).
        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
        return torch.sum(batch_rep * att_w, dim=1)
class SelfAttention(nn.Module):
    """Multi-hop self-attention: produces att_hops weighted summaries of the sequence."""

    def __init__(self, att_size, att_hops):
        super().__init__()
        self.ut_dense = nn.Sequential(nn.Linear(att_size, att_size), nn.Tanh())
        self.et_dense = nn.Linear(att_size, att_hops)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        # inputs: (batch, seq_len, att_size). Scores are (batch, att_hops, seq_len),
        # softmaxed over the sequence; the output is (batch, att_hops, att_size).
        att_scores = self.softmax(self.et_dense(self.ut_dense(inputs)).permute(0, 2, 1))
        return torch.bmm(att_scores, inputs), att_scores
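

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: the config values below
    # are illustrative assumptions, not settings taken from any experiment. Since forward()
    # keeps only the first feature channel and feeds its own predictions back into the input
    # window, feature_dim and output_dim are kept at 1 so the shapes stay consistent.
    args = {
        "num_nodes": 8,
        "feature_dim": 1,
        "output_dim": 1,
        "input_window": 12,
        "output_window": 3,
        "hidden_size": 32,
        "num_layers": 3,
        "natt_hops": 4,
        "nfc": 64,
        "max_up_len": 80,
    }
    model = HierAttnLstm(args)
    model = model.to(model.device)
    # Random input of shape (batch_size, input_window, num_nodes, feature_dim).
    batch = torch.randn(
        2, args["input_window"], args["num_nodes"], args["feature_dim"], device=model.device
    )
    preds = model(batch)
    # Expected output shape: (batch_size, output_window, num_nodes, output_dim).
    print(preds.shape)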