import math

import numpy as np
import torch
import torch.nn as nn


class HierAttnLstm(nn.Module):
    """Hierarchical LSTM: stacked LSTMCells where each upper layer runs over
    attention-pooled segments of the layer below, followed by multi-hop
    self-attention over the top layer and a fully connected output head."""

    def __init__(self, args):
        super().__init__()
        # Data dimensions.
        self.num_nodes = args["num_nodes"]
        self.feature_dim = args["feature_dim"]
        self.output_dim = args["output_dim"]
        self.input_window = args["input_window"]
        self.output_window = args["output_window"]
        # Model hyperparameters.
        self.hidden_size = args["hidden_size"]
        self.num_layers = args["num_layers"]
        self.natt_hops = args["natt_hops"]
        self.nfc = args["nfc"]
        self.max_up_len = args["max_up_len"]

        self.input_size = self.num_nodes * self.feature_dim
        # State tensors are allocated on this device; the module itself should be
        # moved to the same device before calling forward.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Bottom cell consumes the flattened input; upper cells consume pooled hidden states.
        self.lstm_cells = nn.ModuleList(
            [nn.LSTMCell(self.input_size, self.hidden_size)]
            + [
                nn.LSTMCell(self.hidden_size, self.hidden_size)
                for _ in range(self.num_layers - 1)
            ]
        )
        # One pooling module per upper layer, for hidden and cell states respectively.
        self.hidden_state_pooling = nn.ModuleList(
            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)]
        )
        self.cell_state_pooling = nn.ModuleList(
            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)]
        )
        self.self_attention = SelfAttention(self.hidden_size, self.natt_hops)
        self.fc_layer = nn.Sequential(
            nn.Linear(self.hidden_size * self.natt_hops, self.nfc),
            nn.ReLU(),
            nn.Linear(self.nfc, self.num_nodes * self.output_dim),
        )

    def forward(self, batch):
        # batch: (batch_size, input_window, num_nodes, feature_dim).
        # Only the first feature channel is fed forward; together with
        # input_size = num_nodes * feature_dim this effectively assumes
        # feature_dim == 1 (and output_dim == 1 for the autoregressive feedback).
        batch_size = batch.shape[0]
        src = batch.permute(1, 0, 2, 3)[..., :1]
        src = src.reshape(self.input_window, batch_size, -1)
        outputs = []
        for i in range(self.output_window):
            hidden_states = [
                torch.zeros(batch_size, self.hidden_size, device=self.device)
                for _ in range(self.num_layers)
            ]
            cell_states = [
                torch.zeros(batch_size, self.hidden_size, device=self.device)
                for _ in range(self.num_layers)
            ]
            bottom_layer_outputs = []
            cell_states_history = [[] for _ in range(self.num_layers)]

            # Bottom layer: plain LSTMCell unrolled over the input window.
            for t in range(self.input_window):
                hidden_states[0], cell_states[0] = self.lstm_cells[0](
                    src[t], (hidden_states[0], cell_states[0])
                )
                bottom_layer_outputs.append(hidden_states[0])
                cell_states_history[0].append(cell_states[0])
            bottom_layer_outputs = torch.stack(bottom_layer_outputs, 1)
            cell_states_history[0] = torch.stack(cell_states_history[0], 1)

            # Upper layers: each step consumes an attention-pooled segment of the
            # layer below, so the sequence shrinks as it moves up the hierarchy.
            for layer in range(1, self.num_layers):
                layer_inputs = bottom_layer_outputs if layer == 1 else layer_outputs
                layer_outputs, cell_states_history[layer] = [], []
                for start, end in self.calculate_stride(layer_inputs.size(1)):
                    segment = layer_inputs[:, start:end, :]
                    cell_segment = cell_states_history[layer - 1][:, start:end, :]
                    pooled_hidden = self.hidden_state_pooling[layer - 1](segment)
                    pooled_cell = self.cell_state_pooling[layer - 1](
                        torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], 1)
                    )
                    hidden_states[layer], cell_states[layer] = self.lstm_cells[layer](
                        pooled_hidden, (hidden_states[layer], pooled_cell)
                    )
                    layer_outputs.append(hidden_states[layer])
                    cell_states_history[layer].append(cell_states[layer])
                layer_outputs = torch.stack(layer_outputs, 1)
                cell_states_history[layer] = torch.stack(cell_states_history[layer], 1)

            # Multi-hop self-attention over the top layer, then the FC head.
            # Note: num_layers >= 2 is assumed; otherwise layer_outputs is undefined here.
            attended_features, _ = self.self_attention(layer_outputs)
            out = self.fc_layer(attended_features.view(batch_size, -1)).view(
                batch_size, self.num_nodes, self.output_dim
            )
            outputs.append(out.clone())
            # Autoregressive rollout: append the prediction and drop the oldest step.
            if i < self.output_window - 1:
                src = torch.cat((src[1:], out.reshape(batch_size, -1).unsqueeze(0)), 0)
        # (output_window, batch_size, num_nodes, output_dim) -> batch-first.
        return torch.stack(outputs).permute(1, 0, 2, 3)

    def calculate_stride(self, seq_len):
        # Produce consecutive (start, end) index pairs over the sequence; the number
        # of segments grows roughly with sqrt(seq_len) and is capped by max_up_len.
        idx = np.linspace(
            0, seq_len - 1, num=min(self.max_up_len, math.ceil(math.sqrt(seq_len))) + 3
        ).astype(int)
        return list(zip(np.append(idx, seq_len - 1)[:-1], idx[1:]))


class SelfAttentionPooling(nn.Module):
    """Collapse a (batch, seq_len, dim) sequence to (batch, dim) with learned attention weights."""

    def __init__(self, input_dim):
        super().__init__()
        self.W = nn.Linear(input_dim, 1)

    def forward(self, batch_rep):
        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
        return torch.sum(batch_rep * att_w, dim=1)


class SelfAttention(nn.Module):
    """Multi-hop self-attention: maps (batch, seq_len, att_size) to (batch, att_hops, att_size)
    and also returns the attention scores."""

    def __init__(self, att_size, att_hops):
        super().__init__()
        self.ut_dense = nn.Sequential(nn.Linear(att_size, att_size), nn.Tanh())
        self.et_dense = nn.Linear(att_size, att_hops)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, inputs):
        # scores: (batch, att_hops, seq_len), softmax taken over the sequence dimension.
        att_scores = self.softmax(self.et_dense(self.ut_dense(inputs)).permute(0, 2, 1))
        return torch.bmm(att_scores, inputs), att_scores
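

# ---------------------------------------------------------------------------
# Minimal smoke test (not part of the original model code). The hyperparameter
# values below are illustrative assumptions, chosen only to exercise a forward
# pass and check that the output has shape
# (batch_size, output_window, num_nodes, output_dim).
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "num_nodes": 10,
        "feature_dim": 1,
        "output_dim": 1,
        "input_window": 12,
        "output_window": 3,
        "hidden_size": 64,
        "num_layers": 2,
        "natt_hops": 4,
        "nfc": 128,
        "max_up_len": 16,
    }
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = HierAttnLstm(args).to(device)
    batch = torch.randn(
        8, args["input_window"], args["num_nodes"], args["feature_dim"], device=device
    )
    with torch.no_grad():
        pred = model(batch)
    print(pred.shape)  # expected: torch.Size([8, 3, 10, 1])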