TrafficWheel/model/NLT/HierAttnLstm.py


import torch
import torch.nn as nn
import math
import numpy as np


class HierAttnLstm(nn.Module):
    """Hierarchical attention LSTM: a stack of LSTM cells in which each upper
    layer runs over attention-pooled segments of the layer below, followed by a
    multi-hop self-attention readout and a fully connected output head."""

    def __init__(self, args):
        super(HierAttnLstm, self).__init__()
        # self._scaler = self.data_feature.get('scaler')
        self.num_nodes = args['num_nodes']
        self.feature_dim = args['feature_dim']
        self.output_dim = args['output_dim']
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.input_window = args['input_window']
        self.output_window = args['output_window']
        self.hidden_size = args['hidden_size']
        self.num_layers = args['num_layers']
        self.natt_unit = self.hidden_size
        self.natt_hops = args['natt_hops']
        self.nfc = args['nfc']
        self.max_up_len = args['max_up_len']
        self.input_size = self.num_nodes * self.feature_dim
        # One LSTMCell per layer: the bottom cell consumes the flattened node
        # features, the upper cells consume pooled hidden states from below.
        self.lstm_cells = nn.ModuleList([
            nn.LSTMCell(self.input_size, self.hidden_size)
        ] + [
            nn.LSTMCell(self.hidden_size, self.hidden_size)
            for _ in range(self.num_layers - 1)
        ])
        # Attention pooling used to summarize the hidden/cell-state segments
        # that feed each upper layer.
        self.hidden_state_pooling = nn.ModuleList([
            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
        ])
        self.cell_state_pooling = nn.ModuleList([
            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
        ])
        # Multi-hop self-attention readout over the top layer's outputs.
        self.self_attention = SelfAttention(self.natt_unit, self.natt_hops)
        self.fc_layer = nn.Sequential(
            nn.Linear(self.hidden_size * self.natt_hops, self.nfc),
            nn.ReLU(),
            nn.Linear(self.nfc, self.num_nodes * self.output_dim)
        )

    def forward(self, batch):
        # batch: [batch_size, input_window, num_nodes, feature_dim]
        src = batch
        # src = batch['X'].clone()  # [batch_size, input_window, num_nodes, feature_dim]
        src = src.permute(1, 0, 2, 3)  # [input_window, batch_size, num_nodes, feature_dim]
        # Keep only the first feature channel; the reshape below therefore
        # assumes feature_dim == 1, and the autoregressive feedback further
        # assumes output_dim == feature_dim.
        src = src[..., 0:1]
        batch_size = src.shape[1]
        src = src.reshape(self.input_window, batch_size, self.num_nodes * self.feature_dim)
        outputs = []
        for i in range(self.output_window):
            # States are re-initialized for every predicted step.
            hidden_states = [torch.zeros(batch_size, self.hidden_size).to(self.device)
                             for _ in range(self.num_layers)]
            cell_states = [torch.zeros(batch_size, self.hidden_size).to(self.device)
                           for _ in range(self.num_layers)]
            bottom_layer_outputs = []
            cell_states_history = [[] for _ in range(self.num_layers)]
            # Bottom layer: a plain LSTM over the full input window.
            for t in range(self.input_window):
                hidden_states[0], cell_states[0] = self.lstm_cells[0](
                    src[t], (hidden_states[0], cell_states[0]))
                bottom_layer_outputs.append(hidden_states[0])
                cell_states_history[0].append(cell_states[0])
            bottom_layer_outputs = torch.stack(bottom_layer_outputs, dim=1)
            cell_states_history[0] = torch.stack(cell_states_history[0], dim=1)
            # Upper layers: each step consumes an attention-pooled segment of the
            # layer below, so the sequence shortens going up the hierarchy.
            # Note: this requires num_layers >= 2, otherwise layer_outputs is
            # never defined before the attention readout below.
            for layer in range(1, self.num_layers):
                layer_inputs = bottom_layer_outputs if layer == 1 else layer_outputs
                layer_outputs = []
                cell_states_history[layer] = []
                layer_strides = self.calculate_stride(layer_inputs.size(1))
                for start, end in layer_strides:
                    segment = layer_inputs[:, start:end, :]
                    cell_segment = cell_states_history[layer - 1][:, start:end, :]
                    pooled_hidden = self.hidden_state_pooling[layer - 1](segment)
                    pooled_cell = self.cell_state_pooling[layer - 1](
                        torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], dim=1))
                    hidden_states[layer], cell_states[layer] = self.lstm_cells[layer](
                        pooled_hidden, (hidden_states[layer], pooled_cell))
                    layer_outputs.append(hidden_states[layer])
                    cell_states_history[layer].append(cell_states[layer])
                layer_outputs = torch.stack(layer_outputs, dim=1)
                cell_states_history[layer] = torch.stack(cell_states_history[layer], dim=1)
            # layer_outputs: [batch_size, top_sequence_len, hidden_size]
            attended_features, _ = self.self_attention(layer_outputs)
            flattened = attended_features.view(batch_size, -1)
            out = self.fc_layer(flattened)
            out = out.view(batch_size, self.num_nodes, self.output_dim)
            outputs.append(out.clone())
            if i < self.output_window - 1:
                # Autoregressive feedback: drop the oldest input step and append
                # the freshly predicted step.
                src = torch.cat(
                    (src[1:, :, :],
                     out.reshape(batch_size, self.num_nodes * self.feature_dim).unsqueeze(0)),
                    dim=0)
        outputs = torch.stack(outputs)
        # outputs: [output_window, batch_size, num_nodes, output_dim]
        return outputs.permute(1, 0, 2, 3)

    def calculate_stride(self, sequence_len):
        """Split a sequence of length `sequence_len` into consecutive (start, end)
        index pairs used to segment the lower layer before pooling. The end index
        is exclusive when slicing, so the final time step never falls inside a
        segment."""
        up_len = min(self.max_up_len, math.ceil(math.sqrt(sequence_len)))
        idx = np.linspace(0, sequence_len - 1, num=up_len + 3).astype(int)
        if idx[-1] != sequence_len - 1:
            idx = np.append(idx, sequence_len - 1)
        strides = list(zip(idx[:-1], idx[1:]))
        return strides
class SelfAttentionPooling(nn.Module):
    """Pools a sequence of representations into a single vector using a learned
    per-step attention weight."""

    def __init__(self, input_dim):
        super(SelfAttentionPooling, self).__init__()
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        # batch_rep: [batch_size, seq_len, input_dim]
        softmax = nn.functional.softmax
        att_w = softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
        # Weighted sum over the sequence dimension -> [batch_size, input_dim]
        utter_rep = torch.sum(batch_rep * att_w, dim=1)
        return utter_rep
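

# Minimal shape sketch for SelfAttentionPooling (illustrative sizes only):
#   pool = SelfAttentionPooling(input_dim=64)
#   x = torch.randn(8, 12, 64)   # [batch, seq_len, input_dim]
#   pool(x).shape                # torch.Size([8, 64]) -- sequence dim pooled away
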
class SelfAttention(nn.Module):
    """Multi-hop self-attention: produces `att_hops` attention distributions
    over the sequence and returns the corresponding weighted sums."""

    def __init__(self, attention_size, att_hops):
        super(SelfAttention, self).__init__()
        self.ut_dense = nn.Sequential(
            nn.Linear(attention_size, attention_size),
            nn.Tanh()
        )
        self.et_dense = nn.Linear(attention_size, att_hops)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        # inputs: [batch_size, seq_len, hidden_size]
        ut = self.ut_dense(inputs)
        # et: [batch_size, seq_len, att_hops]
        et = self.et_dense(ut)
        # att_scores: [batch_size, att_hops, seq_len], softmax over seq_len
        att_scores = self.softmax(torch.permute(et, (0, 2, 1)))
        # output: [batch_size, att_hops, hidden_size]
        output = torch.bmm(att_scores, inputs)
        return output, att_scores
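

# Minimal smoke-test sketch (not part of the original module). It builds the
# model with assumed hyperparameters -- only the dict keys are taken from
# __init__ above; every value here is illustrative -- and runs one forward pass
# on random data to check the output shape.
if __name__ == "__main__":
    args = {
        'num_nodes': 10,
        'feature_dim': 1,     # forward() keeps only the first feature channel
        'output_dim': 1,
        'input_window': 12,
        'output_window': 3,
        'hidden_size': 64,
        'num_layers': 3,
        'natt_hops': 4,
        'nfc': 128,
        'max_up_len': 8,
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HierAttnLstm(args).to(device)
    x = torch.randn(2, args['input_window'], args['num_nodes'], args['feature_dim'], device=device)
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 3, 10, 1]) = [batch, output_window, num_nodes, output_dim]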