diff --git a/.gitignore b/.gitignore
index 46aa946..a68924d 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,6 +6,7 @@ experiments/
 *.npz
 *.pkl
 data/
+pretrain/

 # ---> Python
 # Byte-compiled / optimized / DLL files
diff --git a/model/DDGCRN/DDGCRN.py b/model/DDGCRN/DDGCRN.py
index 368b3dc..cdf1cc9 100644
--- a/model/DDGCRN/DDGCRN.py
+++ b/model/DDGCRN/DDGCRN.py
@@ -116,4 +116,3 @@ class DGCN(nn.Module):
         D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
         return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul(
             torch.matmul(D_inv, graph + I), D_inv)
-
\ No newline at end of file
diff --git a/model/GWN/GraphWaveNet.py b/model/GWN/GraphWaveNet.py
index 50438f0..d565f71 100644
--- a/model/GWN/GraphWaveNet.py
+++ b/model/GWN/GraphWaveNet.py
@@ -1,23 +1,14 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-import sys
+import torch, torch.nn as nn, torch.nn.functional as F


 class nconv(nn.Module):
-    def __init__(self):
-        super(nconv, self).__init__()
-
-    def forward(self, x, A):
-        x = torch.einsum('ncvl,vw->ncwl', (x, A))
-        return x.contiguous()
+    def forward(self, x, A): return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


 class linear(nn.Module):
     def __init__(self, c_in, c_out):
-        super(linear, self).__init__()
-        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+        super().__init__()
+        self.mlp = nn.Conv2d(c_in, c_out, 1)

     def forward(self, x):
         return self.mlp(x)
@@ -25,191 +16,89 @@ class linear(nn.Module):

 class gcn(nn.Module):
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
-        super(gcn, self).__init__()
+        super().__init__()
         self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp = linear(c_in, c_out)
-        self.dropout = dropout
-        self.order = order
+        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order

     def forward(self, x, support):
         out = [x]
         for a in support:
             x1 = self.nconv(x, a)
             out.append(x1)
-            for k in range(2, self.order + 1):
-                x2 = self.nconv(x1, a)
-                out.append(x2)
-                x1 = x2
-
-        h = torch.cat(out, dim=1)
-        h = self.mlp(h)
-        h = F.dropout(h, self.dropout, training=self.training)
-        return h
+            for _ in range(2, self.order + 1):
+                x1 = self.nconv(x1, a)
+                out.append(x1)
+        return F.dropout(self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training)


 class gwnet(nn.Module):
     def __init__(self, args):
-        super(gwnet, self).__init__()
-        self.dropout = args['dropout']
-        self.blocks = args['blocks']
-        self.layers = args['layers']
-        self.gcn_bool = args['gcn_bool']
-        self.addaptadj = args['addaptadj']
-
-        self.filter_convs = nn.ModuleList()
-        self.gate_convs = nn.ModuleList()
-        self.residual_convs = nn.ModuleList()
-        self.skip_convs = nn.ModuleList()
-        self.bn = nn.ModuleList()
-        self.gconv = nn.ModuleList()
-
-        self.start_conv = nn.Conv2d(in_channels=args['in_dim'],
-                                    out_channels=args['residual_channels'],
-                                    kernel_size=(1, 1))
+        super().__init__()
+        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
+        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
+        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
+        self.residual_convs, self.skip_convs, self.bn, self.gconv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
+        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
         self.supports = args.get('supports', None)
-
-        receptive_field = 1
-
-        self.supports_len = 0
-        if self.supports is not None:
-            self.supports_len += len(self.supports)
-
+        self.supports_len = len(self.supports) if self.supports is not None else 0
         if self.gcn_bool and self.addaptadj:
             aptinit = args.get('aptinit', None)
             if aptinit is None:
-                if self.supports is None:
-                    self.supports = []
-                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10).to(args['device']),
-                                             requires_grad=True).to(args['device'])
-                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes']).to(args['device']),
-                                             requires_grad=True).to(args['device'])
+                if self.supports is None: self.supports = []
+                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10, device=args['device']))
+                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes'], device=args['device']))
                 self.supports_len += 1
             else:
-                if self.supports is None:
-                    self.supports = []
+                if self.supports is None: self.supports = []
                 m, p, n = torch.svd(aptinit)
                 initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                 initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(args['device'])
-                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(args['device'])
+                self.nodevec1 = nn.Parameter(initemb1)
+                self.nodevec2 = nn.Parameter(initemb2)
                 self.supports_len += 1
-
-        kernel_size = args['kernel_size']
-        residual_channels = args['residual_channels']
-        dilation_channels = args['dilation_channels']
-        kernel_size = args['kernel_size']
-        skip_channels = args['skip_channels']
-        end_channels = args['end_channels']
-        out_dim = args['out_dim']
-        dropout = args['dropout']
-
-
+        ks, res, dil, skip, endc, out_dim = args['kernel_size'], args['residual_channels'], args['dilation_channels'], \
+            args['skip_channels'], args['end_channels'], args['out_dim']
+        receptive_field = 1  # still needed by the bookkeeping below and exposed as self.receptive_field
         for b in range(self.blocks):
-            additional_scope = kernel_size - 1
-            new_dilation = 1
+            add_scope, new_dil = ks - 1, 1
             for i in range(self.layers):
-                # dilated convolutions
-                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
-                                                   out_channels=dilation_channels,
-                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
-
-                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
-                                                 out_channels=dilation_channels,
-                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
-
-                # 1x1 convolution for residual connection
-                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
-                                                     out_channels=residual_channels,
-                                                     kernel_size=(1, 1)))
-
-                # 1x1 convolution for skip connection
-                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
-                                                 out_channels=skip_channels,
-                                                 kernel_size=(1, 1)))
-                self.bn.append(nn.BatchNorm2d(residual_channels))
-                new_dilation *= 2
-                receptive_field += additional_scope
-                additional_scope *= 2
-                if self.gcn_bool:
-                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))
-
-        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
-                                    out_channels=end_channels,
-                                    kernel_size=(1, 1),
-                                    bias=True)
-
-        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
-                                    out_channels=out_dim,
-                                    kernel_size=(1, 1),
-                                    bias=True)
-
+                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.residual_convs.append(nn.Conv2d(dil, res, 1))
+                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
+                self.bn.append(nn.BatchNorm2d(res))
+                new_dil *= 2
+                receptive_field += add_scope
+                add_scope *= 2
+                if self.gcn_bool: self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))
+        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
+        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
         self.receptive_field = receptive_field

     def forward(self, input):
-        input = input[..., 0:2]
-        input = input.transpose(1,3)
-        input = nn.functional.pad(input,(1,0,0,0))
+        input = input[..., 0:2].transpose(1, 3)
+        input = F.pad(input, (1, 0, 0, 0))
         in_len = input.size(3)
-        if in_len < self.receptive_field:
-            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-        else:
-            x = input
-        x = self.start_conv(x)
-        skip = 0
-
-        # calculate the current adaptive adj matrix once per iteration
-        new_supports = None
+        x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input
+        x, skip, new_supports = self.start_conv(x), 0, None
         if self.gcn_bool and self.addaptadj and self.supports is not None:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
             new_supports = self.supports + [adp]
-
-        # WaveNet layers
         for i in range(self.blocks * self.layers):
-
-            #            |----------------------------------------|     *residual*
-            #            |                                        |
-            #            |    |-- conv -- tanh --|                |
-            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
-            #                 |-- conv -- sigm --|     |
-            #                                         1x1
-            #                                          |
-            # ---------------------------------------> + ------------->  *skip*
-
-            # (dilation, init_dilation) = self.dilations[i]
-
-            # residual = dilation_func(x, dilation, init_dilation, i)
             residual = x
-            # dilated convolution
-            filter = self.filter_convs[i](residual)
-            filter = torch.tanh(filter)
-            gate = self.gate_convs[i](residual)
-            gate = torch.sigmoid(gate)
-            x = filter * gate
-
-            # parametrized skip connection
-
-            s = x
-            s = self.skip_convs[i](s)
-            try:
-                skip = skip[:, :, :, -s.size(3):]
-            except:
-                skip = 0
-            skip = s + skip
-
+            f = self.filter_convs[i](residual).tanh()
+            g = self.gate_convs[i](residual).sigmoid()
+            x = f * g
+            s = self.skip_convs[i](x)
+            skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s
             if self.gcn_bool and self.supports is not None:
-                if self.addaptadj:
-                    x = self.gconv[i](x, new_supports)
-                else:
-                    x = self.gconv[i](x, self.supports)
+                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
             else:
                 x = self.residual_convs[i](x)
-
-            x = x + residual[:, :, :, -x.size(3):]
-
-            x = self.bn[i](x)
-
-        x = F.relu(skip)
-        x = F.relu(self.end_conv_1(x))
-        x = self.end_conv_2(x)
-        return x
+            x = x + residual[:, :, :, -x.size(3):]  # keep the residual connection and batch norm
+            x = self.bn[i](x)
+        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
diff --git a/model/GWN/GraphWaveNet_bk.py b/model/GWN/GraphWaveNet_bk.py
new file mode 100644
index 0000000..50438f0
--- /dev/null
+++ b/model/GWN/GraphWaveNet_bk.py
@@ -0,0 +1,215 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import sys
+
+
+class nconv(nn.Module):
+    def __init__(self):
+        super(nconv, self).__init__()
+
+    def forward(self, x, A):
+        x = torch.einsum('ncvl,vw->ncwl', (x, A))
+        return x.contiguous()
+
+
+class linear(nn.Module):
+    def __init__(self, c_in, c_out):
+        super(linear, self).__init__()
+        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+
+    def forward(self, x):
+        return self.mlp(x)
+
+
+class gcn(nn.Module):
+    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
+        super(gcn, self).__init__()
+        self.nconv = nconv()
+        c_in = (order * support_len + 1) * c_in
+        self.mlp = linear(c_in, c_out)
+        self.dropout = dropout
+        self.order = order
+
+    def forward(self, x, support):
+        out = [x]
+        for a in support:
+            x1 = self.nconv(x, a)
+            out.append(x1)
+            for k in range(2, self.order + 1):
+                x2 = self.nconv(x1, a)
+                out.append(x2)
+                x1 = x2
+
+        h = torch.cat(out, dim=1)
+        h = self.mlp(h)
+        h = F.dropout(h, self.dropout, training=self.training)
+        return h
+
+
+class gwnet(nn.Module):
+    def __init__(self, args):
+        super(gwnet, self).__init__()
+        self.dropout = args['dropout']
+        self.blocks = args['blocks']
+        self.layers = args['layers']
+        self.gcn_bool = args['gcn_bool']
+        self.addaptadj = args['addaptadj']
+
+        self.filter_convs = nn.ModuleList()
+        self.gate_convs = nn.ModuleList()
+        self.residual_convs = nn.ModuleList()
+        self.skip_convs = nn.ModuleList()
+        self.bn = nn.ModuleList()
+        self.gconv = nn.ModuleList()
+
+        self.start_conv = nn.Conv2d(in_channels=args['in_dim'],
+                                    out_channels=args['residual_channels'],
+                                    kernel_size=(1, 1))
+        self.supports = args.get('supports', None)
+
+        receptive_field = 1
+
+        self.supports_len = 0
+        if self.supports is not None:
+            self.supports_len += len(self.supports)
+
+        if self.gcn_bool and self.addaptadj:
+            aptinit = args.get('aptinit', None)
+            if aptinit is None:
+                if self.supports is None:
+                    self.supports = []
+                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10).to(args['device']),
+                                             requires_grad=True).to(args['device'])
+                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes']).to(args['device']),
+                                             requires_grad=True).to(args['device'])
+                self.supports_len += 1
+            else:
+                if self.supports is None:
+                    self.supports = []
+                m, p, n = torch.svd(aptinit)
+                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
+                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
+                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(args['device'])
+                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(args['device'])
+                self.supports_len += 1
+
+        kernel_size = args['kernel_size']
+        residual_channels = args['residual_channels']
+        dilation_channels = args['dilation_channels']
+        kernel_size = args['kernel_size']
+        skip_channels = args['skip_channels']
+        end_channels = args['end_channels']
+        out_dim = args['out_dim']
+        dropout = args['dropout']
+
+
+        for b in range(self.blocks):
+            additional_scope = kernel_size - 1
+            new_dilation = 1
+            for i in range(self.layers):
+                # dilated convolutions
+                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
+                                                   out_channels=dilation_channels,
+                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
+
+                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
+                                                 out_channels=dilation_channels,
+                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
+
+                # 1x1 convolution for residual connection
+                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
+                                                     out_channels=residual_channels,
+                                                     kernel_size=(1, 1)))
+
+                # 1x1 convolution for skip connection
+                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
+                                                 out_channels=skip_channels,
+                                                 kernel_size=(1, 1)))
+                self.bn.append(nn.BatchNorm2d(residual_channels))
+                new_dilation *= 2
+                receptive_field += additional_scope
+                additional_scope *= 2
+                if self.gcn_bool:
+                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))
+
+        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
+                                    out_channels=end_channels,
+                                    kernel_size=(1, 1),
+                                    bias=True)
+
+        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
+                                    out_channels=out_dim,
+                                    kernel_size=(1, 1),
+                                    bias=True)
+
+        self.receptive_field = receptive_field
+
+    def forward(self, input):
+        input = input[..., 0:2]
+        input = input.transpose(1,3)
+        input = nn.functional.pad(input,(1,0,0,0))
+        in_len = input.size(3)
+        if in_len < self.receptive_field:
+            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
+        else:
+            x = input
+        x = self.start_conv(x)
+        skip = 0
+
+        # calculate the current adaptive adj matrix once per iteration
+        new_supports = None
+        if self.gcn_bool and self.addaptadj and self.supports is not None:
+            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
+            new_supports = self.supports + [adp]
+
+        # WaveNet layers
+        for i in range(self.blocks * self.layers):
+
+            #            |----------------------------------------|     *residual*
+            #            |                                        |
+            #            |    |-- conv -- tanh --|                |
+            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
+            #                 |-- conv -- sigm --|     |
+            #                                         1x1
+            #                                          |
+            # ---------------------------------------> + ------------->  *skip*
+
+            # (dilation, init_dilation) = self.dilations[i]
+
+            # residual = dilation_func(x, dilation, init_dilation, i)
+            residual = x
+            # dilated convolution
+            filter = self.filter_convs[i](residual)
+            filter = torch.tanh(filter)
+            gate = self.gate_convs[i](residual)
+            gate = torch.sigmoid(gate)
+            x = filter * gate
+
+            # parametrized skip connection
+
+            s = x
+            s = self.skip_convs[i](s)
+            try:
+                skip = skip[:, :, :, -s.size(3):]
+            except:
+                skip = 0
+            skip = s + skip
+
+            if self.gcn_bool and self.supports is not None:
+                if self.addaptadj:
+                    x = self.gconv[i](x, new_supports)
+                else:
+                    x = self.gconv[i](x, self.supports)
+            else:
+                x = self.residual_convs[i](x)
+
+            x = x + residual[:, :, :, -x.size(3):]
+
+            x = self.bn[i](x)
+
+        x = F.relu(skip)
+        x = F.relu(self.end_conv_1(x))
+        x = self.end_conv_2(x)
+        return x
diff --git a/model/NLT/HierAttnLstm.py b/model/NLT/HierAttnLstm.py
index 9d3640a..503f305 100644
--- a/model/NLT/HierAttnLstm.py
+++ b/model/NLT/HierAttnLstm.py
@@ -1,147 +1,95 @@
 import torch
 import torch.nn as nn
-import math
-import numpy as np
+import math, numpy as np


 class HierAttnLstm(nn.Module):
     def __init__(self, args):
-        super(HierAttnLstm, self).__init__()
-        # self._scaler = self.data_feature.get('scaler')
-        self.num_nodes = args['num_nodes']
-        self.feature_dim = args['feature_dim']
-        self.output_dim = args['output_dim']
+        super().__init__()
+        self.num_nodes, self.feature_dim, self.output_dim = args['num_nodes'], args['feature_dim'], args['output_dim']
+        self.input_window, self.output_window = args['input_window'], args['output_window']
+        self.hidden_size, self.num_layers = args['hidden_size'], args['num_layers']
+        self.natt_hops, self.nfc, self.max_up_len = args['natt_hops'], args['nfc'], args['max_up_len']
+        self.input_size = self.num_nodes * self.feature_dim
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.input_window = args['input_window']
-        self.output_window = args['output_window']
-
-        self.hidden_size = args['hidden_size']
-        self.num_layers = args['num_layers']
-        self.natt_unit = self.hidden_size
-        self.natt_hops = args['natt_hops']
-        self.nfc = args['nfc']
-        self.max_up_len = args['max_up_len']
-
-        self.input_size = self.num_nodes * self.feature_dim
-
-        self.lstm_cells = nn.ModuleList([
-            nn.LSTMCell(self.input_size, self.hidden_size)
-        ] + [
-            nn.LSTMCell(self.hidden_size, self.hidden_size) for _ in
-            range(self.num_layers - 1)
-        ])
-        self.hidden_state_pooling = nn.ModuleList([
-            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
-        ])
-        self.cell_state_pooling = nn.ModuleList([
-            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
-        ])
-        self.self_attention = SelfAttention(self.natt_unit, self.natt_hops)
+        self.lstm_cells = nn.ModuleList([nn.LSTMCell(self.input_size, self.hidden_size)] +
+                                        [nn.LSTMCell(self.hidden_size, self.hidden_size) for _ in
+                                         range(self.num_layers - 1)])
+        self.hidden_state_pooling = nn.ModuleList(
+            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)])
+        self.cell_state_pooling = nn.ModuleList(
+            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)])
+        self.self_attention = SelfAttention(self.hidden_size, self.natt_hops)
         self.fc_layer = nn.Sequential(
-            nn.Linear(self.hidden_size * self.natt_hops, self.nfc),
-            nn.ReLU(),
-            nn.Linear(self.nfc, self.num_nodes * self.output_dim)
-        )
+            nn.Linear(self.hidden_size * self.natt_hops, self.nfc), nn.ReLU(),
+            nn.Linear(self.nfc, self.num_nodes * self.output_dim))

     def forward(self, batch):
-        src = batch
-        # src = batch['X'].clone()  # [batch_size, input_window, num_nodes, feature_dim]
-        src = src.permute(1, 0, 2, 3)  # [input_window, batch_size, num_nodes, feature_dim]
-        # print("src shape: ", src.shape)
-        src = src[..., 0:1]
-        batch_size = src.shape[1]
-        src = src.reshape(self.input_window, batch_size, self.num_nodes * self.feature_dim)
+        src, batch_size = batch.permute(1, 0, 2, 3)[..., :1], batch.shape[0]
+        src = src.reshape(self.input_window, batch_size, -1)

         outputs = []
         for i in range(self.output_window):
-            hidden_states = [torch.zeros(batch_size, self.hidden_size).to(self.device) for _ in range(self.num_layers)]
-            cell_states = [torch.zeros(batch_size, self.hidden_size).to(self.device) for _ in range(self.num_layers)]
+            hidden_states, cell_states = [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in
+                                          range(self.num_layers)], \
+                [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)]
+            bottom_layer_outputs, cell_states_history = [], [[] for _ in range(self.num_layers)]

-            bottom_layer_outputs = []
-            cell_states_history = [[] for _ in range(self.num_layers)]
             for t in range(self.input_window):
                 hidden_states[0], cell_states[0] = self.lstm_cells[0](src[t], (hidden_states[0], cell_states[0]))
                 bottom_layer_outputs.append(hidden_states[0])
                 cell_states_history[0].append(cell_states[0])

-            bottom_layer_outputs = torch.stack(bottom_layer_outputs, dim=1)
-            cell_states_history[0] = torch.stack(cell_states_history[0], dim=1)
+            bottom_layer_outputs, cell_states_history[0] = torch.stack(bottom_layer_outputs, 1), torch.stack(
+                cell_states_history[0], 1)

             for layer in range(1, self.num_layers):
                 layer_inputs = bottom_layer_outputs if layer == 1 else layer_outputs
-                layer_outputs = []
-                cell_states_history[layer] = []
-                layer_strides = self.calculate_stride(layer_inputs.size(1))
-
-                for start, end in layer_strides:
-                    segment = layer_inputs[:, start:end, :]
-                    cell_segment = cell_states_history[layer - 1][:, start:end, :]
-
-                    pooled_hidden = self.hidden_state_pooling[layer - 1](segment)
-                    pooled_cell = self.cell_state_pooling[layer - 1](
-                        torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], dim=1))
+                layer_outputs, cell_states_history[layer] = [], []
+                for start, end in self.calculate_stride(layer_inputs.size(1)):
+                    segment, cell_segment = layer_inputs[:, start:end, :], cell_states_history[layer - 1][:, start:end,
+                                                                           :]
+                    pooled_hidden, pooled_cell = self.hidden_state_pooling[layer - 1](segment), self.cell_state_pooling[
+                        layer - 1](torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], 1))
                     hidden_states[layer], cell_states[layer] = self.lstm_cells[layer](pooled_hidden, (
-                    hidden_states[layer], pooled_cell))
+                        hidden_states[layer], pooled_cell))
                     layer_outputs.append(hidden_states[layer])
                     cell_states_history[layer].append(cell_states[layer])

-                layer_outputs = torch.stack(layer_outputs, dim=1)
-                cell_states_history[layer] = torch.stack(cell_states_history[layer], dim=1)
-
-            # print("layer_outputs shape: ", layer_outputs.shape)  # [batch, sequence, hidden_size]
+                layer_outputs, cell_states_history[layer] = torch.stack(layer_outputs, 1), torch.stack(
+                    cell_states_history[layer], 1)

             attended_features, _ = self.self_attention(layer_outputs)
-            flattened = attended_features.view(batch_size, -1)
-            out = self.fc_layer(flattened)
-            out = out.view(batch_size, self.num_nodes, self.output_dim)
+            out = self.fc_layer(attended_features.view(batch_size, -1)).view(batch_size, self.num_nodes,
+                                                                             self.output_dim)
             outputs.append(out.clone())

-            if i < self.output_window - 1:
-                src = torch.cat(
-                    (src[1:, :, :], out.reshape(batch_size, self.num_nodes * self.feature_dim).unsqueeze(0)), dim=0)
+            src = torch.cat((src[1:], out.reshape(batch_size, -1).unsqueeze(0)), 0)

-        outputs = torch.stack(outputs)
-        # outputs = [output_window, batch_size, num_nodes, output_dim]
-        return outputs.permute(1, 0, 2, 3)
+        return torch.stack(outputs).permute(1, 0, 2, 3)

-    def calculate_stride(self, sequence_len):
-        up_len = min(self.max_up_len, math.ceil(math.sqrt(sequence_len)))
-        idx = np.linspace(0, sequence_len - 1, num=up_len + 3).astype(int)
-        if idx[-1] != sequence_len - 1:
-            idx = np.append(idx, sequence_len - 1)
-        strides = list(zip(idx[:-1], idx[1:]))
-        return strides
+    def calculate_stride(self, seq_len):
+        idx = np.linspace(0, seq_len - 1, num=min(self.max_up_len, math.ceil(math.sqrt(seq_len))) + 3).astype(int)
+        return list(zip(np.append(idx, seq_len - 1)[:-1], idx[1:]))


 class SelfAttentionPooling(nn.Module):
     def __init__(self, input_dim):
-        super(SelfAttentionPooling, self).__init__()
+        super().__init__()
         self.W = nn.Linear(input_dim, 1)

     def forward(self, batch_rep):
-        softmax = nn.functional.softmax
-        att_w = softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
-        utter_rep = torch.sum(batch_rep * att_w, dim=1)
-        return utter_rep
+        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
+        return torch.sum(batch_rep * att_w, dim=1)


 class SelfAttention(nn.Module):
-    def __init__(self, attention_size, att_hops):
-        super(SelfAttention, self).__init__()
-        self.ut_dense = nn.Sequential(
-            nn.Linear(attention_size, attention_size),
-            nn.Tanh()
-        )
-        self.et_dense = nn.Linear(attention_size, att_hops)
-        self.softmax = nn.Softmax(dim=-1)
+    def __init__(self, att_size, att_hops):
+        super().__init__()
+        self.ut_dense = nn.Sequential(nn.Linear(att_size, att_size), nn.Tanh())
+        self.et_dense, self.softmax = nn.Linear(att_size, att_hops), nn.Softmax(dim=-1)

     def forward(self, inputs):
-        # inputs is a 3D Tensor: batch, len, hidden_size
-        # scores is a 2D Tensor: batch, len
-        ut = self.ut_dense(inputs)
-        # et shape: [batch_size, seq_len, att_hops]
-        et = self.et_dense(ut)
-        att_scores = self.softmax(torch.permute(et, (0, 2, 1)))
-        output = torch.bmm(att_scores, inputs)
-        return output, att_scores
+        att_scores = self.softmax(self.et_dense(self.ut_dense(inputs)).permute(0, 2, 1))
+        return torch.bmm(att_scores, inputs), att_scores
diff --git a/model/model_selector.py b/model/model_selector.py
index f85c73e..54b8c5f 100644
--- a/model/model_selector.py
+++ b/model/model_selector.py
@@ -14,6 +14,7 @@ from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
 from model.EXP.EXP import EXP
+from model.EXPB.EXP_b import EXPB

 def model_selector(model):
     match model['type']:
@@ -33,4 +34,5 @@ def model_selector(model):
         case 'STGODE': return ODEGCN(model)
         case 'PDG2SEQ': return PDG2Seq(model)
         case 'EXP': return EXP(model)
+        case 'EXPB': return EXPB(model)
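Since the old implementation is preserved as GraphWaveNet_bk.py, one way to sanity-check the condensed gwnet is to run both on the same config and a random batch and compare output shapes. This is only an illustrative sketch: every hyperparameter value below is an assumption, not a project default, and the two modules initialize their weights independently, so only shapes (not values) are expected to agree.

# Hypothetical smoke test for the GWN refactor; all config values below are illustrative assumptions.
import torch
from model.GWN.GraphWaveNet import gwnet
from model.GWN.GraphWaveNet_bk import gwnet as gwnet_bk

args = {
    'device': 'cpu', 'num_nodes': 207, 'in_dim': 2, 'out_dim': 12,
    'residual_channels': 32, 'dilation_channels': 32, 'skip_channels': 256,
    'end_channels': 512, 'kernel_size': 2, 'blocks': 4, 'layers': 2,
    'dropout': 0.3, 'gcn_bool': True, 'addaptadj': True,  # adaptive adjacency only, no preset supports
}
x = torch.randn(8, 13, 207, 2)  # [batch, time, nodes, features]
new_out, old_out = gwnet(args)(x), gwnet_bk(args)(x)
assert new_out.shape == old_out.shape, (new_out.shape, old_out.shape)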