# TrafficWheel/model/DSANET/DSANET.py
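# DSANet (Dual Self-Attention Network): a global self-attention branch and a
# local convolutional self-attention branch model dependencies across the
# input series, while a parallel autoregressive (AR) component captures the
# linear part of the signal; the two forecasts are summed.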
import torch
import torch.nn as nn
import torch.nn.functional as F

from model.DSANET.Layers import EncoderLayer


class Single_Global_SelfAttn_Module(nn.Module):
    def __init__(
        self,
        window,
        n_multiv,
        n_kernels,
        w_kernel,
        d_k,
        d_v,
        d_model,
        d_inner,
        n_layers,
        n_head,
        drop_prob=0.1,
    ):
        """
        Args:
            window (int): length of the input window
            n_multiv (int): number of univariate time series
            n_kernels (int): number of convolution channels
            w_kernel (int): width of the convolution kernel (default 1)
            d_k (int): dimension of keys, d_model / n_head
            d_v (int): dimension of values, d_model / n_head
            d_model (int): dimension of the encoder representation
            d_inner (int): inner dimension of the position-wise feed-forward network
            n_layers (int): number of encoder layers
            n_head (int): number of attention heads
            drop_prob (float): dropout probability
        """
        super(Single_Global_SelfAttn_Module, self).__init__()

        self.window = window
        self.w_kernel = w_kernel
        self.n_multiv = n_multiv
        self.d_model = d_model
        self.drop_prob = drop_prob
        # A kernel spanning the full window collapses the time dimension,
        # producing one n_kernels-dim feature vector per series.
        self.conv2 = nn.Conv2d(1, n_kernels, (window, w_kernel))
        self.in_linear = nn.Linear(n_kernels, d_model)
        self.out_linear = nn.Linear(d_model, n_kernels)
        self.layer_stack = nn.ModuleList(
            [
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=drop_prob)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x, return_attns=False):
        # x: [batch, window, n_multiv] -> [batch, 1, window, n_multiv]
        x = x.view(-1, self.w_kernel, self.window, self.n_multiv)
        x2 = F.relu(self.conv2(x))  # [batch, n_kernels, 1, n_multiv]
        # Use the functional form so dropout is disabled in eval mode;
        # a fresh nn.Dropout built inside forward() would always apply it.
        x2 = F.dropout(x2, p=self.drop_prob, training=self.training)
        x = torch.squeeze(x2, 2)  # [batch, n_kernels, n_multiv]
        x = torch.transpose(x, 1, 2)  # [batch, n_multiv, n_kernels]
        src_seq = self.in_linear(x)  # [batch, n_multiv, d_model]

        enc_slf_attn_list = []
        enc_output = src_seq
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        # Project back to n_kernels before returning so both branches of
        # return_attns yield outputs of the same shape.
        enc_output = self.out_linear(enc_output)
        if return_attns:
            return enc_output, enc_slf_attn_list
        return (enc_output,)


class Single_Local_SelfAttn_Module(nn.Module):
    def __init__(
        self,
        window,
        local,
        n_multiv,
        n_kernels,
        w_kernel,
        d_k,
        d_v,
        d_model,
        d_inner,
        n_layers,
        n_head,
        drop_prob=0.1,
    ):
        """
        Args:
            window (int): length of the input window
            local (int): length of the local convolution kernel along time
            n_multiv (int): number of univariate time series
            n_kernels (int): number of convolution channels
            w_kernel (int): width of the convolution kernel (default 1)
            d_k (int): dimension of keys, d_model / n_head
            d_v (int): dimension of values, d_model / n_head
            d_model (int): dimension of the encoder representation
            d_inner (int): inner dimension of the position-wise feed-forward network
            n_layers (int): number of encoder layers
            n_head (int): number of attention heads
            drop_prob (float): dropout probability
        """
        super(Single_Local_SelfAttn_Module, self).__init__()

        self.window = window
        self.w_kernel = w_kernel
        self.n_multiv = n_multiv
        self.d_model = d_model
        self.drop_prob = drop_prob
        # A short kernel of length `local` slides over the window; max
        # pooling then collapses the time dimension to a single step.
        self.conv1 = nn.Conv2d(1, n_kernels, (local, w_kernel))
        self.pooling1 = nn.AdaptiveMaxPool2d((1, n_multiv))
        self.in_linear = nn.Linear(n_kernels, d_model)
        self.out_linear = nn.Linear(d_model, n_kernels)
        self.layer_stack = nn.ModuleList(
            [
                EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=drop_prob)
                for _ in range(n_layers)
            ]
        )

    def forward(self, x, return_attns=False):
        # x: [batch, window, n_multiv] -> [batch, 1, window, n_multiv]
        x = x.view(-1, self.w_kernel, self.window, self.n_multiv)
        x1 = F.relu(self.conv1(x))  # [batch, n_kernels, window - local + 1, n_multiv]
        x1 = self.pooling1(x1)  # [batch, n_kernels, 1, n_multiv]
        # Functional dropout respects self.training, unlike a fresh nn.Dropout.
        x1 = F.dropout(x1, p=self.drop_prob, training=self.training)
        x = torch.squeeze(x1, 2)  # [batch, n_kernels, n_multiv]
        x = torch.transpose(x, 1, 2)  # [batch, n_multiv, n_kernels]
        src_seq = self.in_linear(x)  # [batch, n_multiv, d_model]

        enc_slf_attn_list = []
        enc_output = src_seq
        for enc_layer in self.layer_stack:
            enc_output, enc_slf_attn = enc_layer(enc_output)
            if return_attns:
                enc_slf_attn_list += [enc_slf_attn]

        # Project back to n_kernels before returning (see the global module).
        enc_output = self.out_linear(enc_output)
        if return_attns:
            return enc_output, enc_slf_attn_list
        return (enc_output,)


class AR(nn.Module):
    """Linear autoregressive component; the forecast horizon is hardcoded to 12."""

    def __init__(self, window):
        super(AR, self).__init__()
        self.linear = nn.Linear(window, 12)

    def forward(self, x):
        # x: [batch, window, n_multiv]; regress each series over its own window
        x = torch.transpose(x, 1, 2)  # [batch, n_multiv, window]
        x = self.linear(x)  # [batch, n_multiv, 12]
        x = torch.transpose(x, 1, 2)  # [batch, 12, n_multiv]
        return x


class DSANet(nn.Module):
    def __init__(self, args):
        """
        Args:
            args (dict): model and dataset hyperparameters; see the keys read below.
        """
        super(DSANet, self).__init__()

        self.hparams = args
        self.batch_size = args["batch_size"]

        # parameters from the dataset
        self.window = args["window"]
        self.local = args["local"]
        self.n_multiv = args["n_multiv"]
        self.n_kernels = args["n_kernels"]
        self.w_kernel = args["w_kernel"]

        # hyperparameters of the model
        self.d_model = args["d_model"]
        self.d_inner = args["d_inner"]
        self.n_layers = args["n_layers"]
        self.n_head = args["n_head"]
        self.d_k = args["d_k"]
        self.d_v = args["d_v"]
        self.drop_prob = args["drop_prob"]

        # global branch: one convolution spanning the whole window
        self.sgsf = Single_Global_SelfAttn_Module(
            window=self.window,
            n_multiv=self.n_multiv,
            n_kernels=self.n_kernels,
            w_kernel=self.w_kernel,
            d_k=self.d_k,
            d_v=self.d_v,
            d_model=self.d_model,
            d_inner=self.d_inner,
            n_layers=self.n_layers,
            n_head=self.n_head,
            drop_prob=self.drop_prob,
        )
        # local branch: short convolutions plus max pooling
        self.slsf = Single_Local_SelfAttn_Module(
            window=self.window,
            local=self.local,
            n_multiv=self.n_multiv,
            n_kernels=self.n_kernels,
            w_kernel=self.w_kernel,
            d_k=self.d_k,
            d_v=self.d_v,
            d_model=self.d_model,
            d_inner=self.d_inner,
            n_layers=self.n_layers,
            n_head=self.n_head,
            drop_prob=self.drop_prob,
        )
        self.ar = AR(window=self.window)
        # maps the concatenated branch features to the 12-step horizon
        self.W_output1 = nn.Linear(2 * self.n_kernels, 12)
        self.dropout = nn.Dropout(p=self.drop_prob)
        self.active_func = nn.Tanh()

    def forward(self, x):
        """
        x: [batch, window, n_multiv, n_features]; only the first feature
        channel is used. Returns [batch, 12, n_multiv, 1].
        """
        x = x[..., 0]  # [batch, window, n_multiv]
        sgsf_output, *_ = self.sgsf(x)  # [batch, n_multiv, n_kernels]
        slsf_output, *_ = self.slsf(x)  # [batch, n_multiv, n_kernels]
        sf_output = torch.cat((sgsf_output, slsf_output), 2)  # [batch, n_multiv, 2 * n_kernels]
        sf_output = self.dropout(sf_output)
        sf_output = self.W_output1(sf_output)  # [batch, n_multiv, 12]
        sf_output = torch.transpose(sf_output, 1, 2)  # [batch, 12, n_multiv]

        ar_output = self.ar(x)  # [batch, 12, n_multiv]

        # self-attention forecast plus the linear AR forecast
        output = sf_output + ar_output
        output = output.unsqueeze(dim=-1)  # [batch, 12, n_multiv, 1]
        return output
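

# Minimal smoke-test sketch. The hyperparameter values below are illustrative
# assumptions, not TrafficWheel's configured values, and this assumes
# model.DSANET.Layers.EncoderLayer is importable from the repo root.
if __name__ == "__main__":
    args = {
        "batch_size": 4,
        "window": 12,
        "local": 3,
        "n_multiv": 207,
        "n_kernels": 32,
        "w_kernel": 1,
        "d_model": 512,
        "d_inner": 2048,
        "n_layers": 2,
        "n_head": 8,
        "d_k": 64,  # d_model / n_head
        "d_v": 64,  # d_model / n_head
        "drop_prob": 0.1,
    }
    model = DSANet(args)
    model.eval()  # dropout off during the shape check
    # [batch, window, n_multiv, n_features]; only feature 0 is consumed
    x = torch.randn(args["batch_size"], args["window"], args["n_multiv"], 1)
    with torch.no_grad():
        y = model(x)
    print(y.shape)  # expected: torch.Size([4, 12, 207, 1])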