Simplify NLT, DDGCRN, GWN code
This commit is contained in:
parent 1b76cc6ce2
commit e8fc67b867
@@ -6,6 +6,7 @@ experiments/
 *.npz
 *.pkl
 data/
+pretrain/
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
@@ -116,4 +116,3 @@ class DGCN(nn.Module):
         D_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
         return torch.matmul(torch.matmul(D_inv, graph), D_inv) if normalize else torch.matmul(
             torch.matmul(D_inv, graph + I), D_inv)
-
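For context, the value returned above is the symmetrically normalized adjacency D^(-1/2) A D^(-1/2) (with A + I in the non-normalize branch). A minimal standalone sketch of the same computation; the shapes below are illustrative assumptions, not taken from the repository's configs:

import torch

def sym_norm(graph, normalize=True):
    # graph: [..., N, N] adjacency matrices with strictly positive row sums
    d_inv = torch.diag_embed(torch.sum(graph, -1) ** (-0.5))
    eye = torch.eye(graph.size(-1), device=graph.device)
    a = graph if normalize else graph + eye
    return torch.matmul(torch.matmul(d_inv, a), d_inv)

adj = torch.rand(4, 10, 10) + 0.1   # toy batch of graphs, kept positive to avoid inf degrees
print(sym_norm(adj).shape)          # torch.Size([4, 10, 10])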
@@ -1,23 +1,14 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.autograd import Variable
-import sys
+import torch, torch.nn as nn, torch.nn.functional as F
 
 
 class nconv(nn.Module):
-    def __init__(self):
-        super(nconv, self).__init__()
-
-    def forward(self, x, A):
-        x = torch.einsum('ncvl,vw->ncwl', (x, A))
-        return x.contiguous()
+    def forward(self, x, A): return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()
 
 
 class linear(nn.Module):
     def __init__(self, c_in, c_out):
-        super(linear, self).__init__()
-        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+        super().__init__()
+        self.mlp = nn.Conv2d(c_in, c_out, 1)
 
     def forward(self, x):
         return self.mlp(x)
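As a quick shape check of the einsum contraction that nconv now performs inline; the tensor sizes are illustrative assumptions, not values from the repository's configs:

import torch

x = torch.randn(8, 32, 207, 13)              # [batch, channels, nodes, time]
A = torch.softmax(torch.randn(207, 207), 1)  # one support matrix
out = torch.einsum('ncvl,vw->ncwl', (x, A))  # mixes the node dimension v -> w
print(out.shape)                             # torch.Size([8, 32, 207, 13])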
@@ -25,191 +16,86 @@ class linear(nn.Module):
 
 
 class gcn(nn.Module):
     def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
-        super(gcn, self).__init__()
+        super().__init__()
         self.nconv = nconv()
         c_in = (order * support_len + 1) * c_in
-        self.mlp = linear(c_in, c_out)
-        self.dropout = dropout
-        self.order = order
+        self.mlp, self.dropout, self.order = linear(c_in, c_out), dropout, order
 
     def forward(self, x, support):
         out = [x]
         for a in support:
             x1 = self.nconv(x, a)
             out.append(x1)
-            for k in range(2, self.order + 1):
-                x2 = self.nconv(x1, a)
-                out.append(x2)
-                x1 = x2
-
-        h = torch.cat(out, dim=1)
-        h = self.mlp(h)
-        h = F.dropout(h, self.dropout, training=self.training)
-        return h
+            for _ in range(2, self.order + 1):
+                x1 = self.nconv(x1, a)
+                out.append(x1)
+        return F.dropout(self.mlp(torch.cat(out, dim=1)), self.dropout, training=self.training)
 
 
 class gwnet(nn.Module):
     def __init__(self, args):
-        super(gwnet, self).__init__()
-        self.dropout = args['dropout']
-        self.blocks = args['blocks']
-        self.layers = args['layers']
-        self.gcn_bool = args['gcn_bool']
-        self.addaptadj = args['addaptadj']
-
-        self.filter_convs = nn.ModuleList()
-        self.gate_convs = nn.ModuleList()
-        self.residual_convs = nn.ModuleList()
-        self.skip_convs = nn.ModuleList()
-        self.bn = nn.ModuleList()
-        self.gconv = nn.ModuleList()
-
-        self.start_conv = nn.Conv2d(in_channels=args['in_dim'],
-                                    out_channels=args['residual_channels'],
-                                    kernel_size=(1, 1))
+        super().__init__()
+        self.dropout, self.blocks, self.layers = args['dropout'], args['blocks'], args['layers']
+        self.gcn_bool, self.addaptadj = args['gcn_bool'], args['addaptadj']
+        self.filter_convs, self.gate_convs = nn.ModuleList(), nn.ModuleList()
+        self.residual_convs, self.skip_convs, self.bn, self.gconv = nn.ModuleList(), nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
+        self.start_conv = nn.Conv2d(args['in_dim'], args['residual_channels'], 1)
         self.supports = args.get('supports', None)
 
         receptive_field = 1
-
-        self.supports_len = 0
-        if self.supports is not None:
-            self.supports_len += len(self.supports)
-
+        self.supports_len = len(self.supports) if self.supports is not None else 0
         if self.gcn_bool and self.addaptadj:
             aptinit = args.get('aptinit', None)
             if aptinit is None:
-                if self.supports is None:
-                    self.supports = []
-                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10).to(args['device']),
-                                             requires_grad=True).to(args['device'])
-                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes']).to(args['device']),
-                                             requires_grad=True).to(args['device'])
+                if self.supports is None: self.supports = []
+                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10, device=args['device']))
+                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes'], device=args['device']))
                 self.supports_len += 1
             else:
-                if self.supports is None:
-                    self.supports = []
+                if self.supports is None: self.supports = []
                 m, p, n = torch.svd(aptinit)
                 initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                 initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
-                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(args['device'])
-                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(args['device'])
+                self.nodevec1 = nn.Parameter(initemb1)
+                self.nodevec2 = nn.Parameter(initemb2)
                 self.supports_len += 1
 
-        kernel_size = args['kernel_size']
-        residual_channels = args['residual_channels']
-        dilation_channels = args['dilation_channels']
-        kernel_size = args['kernel_size']
-        skip_channels = args['skip_channels']
-        end_channels = args['end_channels']
-        out_dim = args['out_dim']
-        dropout = args['dropout']
+        ks, res, dil, skip, endc, out_dim = args['kernel_size'], args['residual_channels'], args['dilation_channels'], \
+            args['skip_channels'], args['end_channels'], args['out_dim']
 
         for b in range(self.blocks):
-            additional_scope = kernel_size - 1
-            new_dilation = 1
+            add_scope, new_dil = ks - 1, 1
            for i in range(self.layers):
-                # dilated convolutions
-                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
-                                                   out_channels=dilation_channels,
-                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
-
-                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
-                                                 out_channels=dilation_channels,
-                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
-
-                # 1x1 convolution for residual connection
-                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
-                                                     out_channels=residual_channels,
-                                                     kernel_size=(1, 1)))
-
-                # 1x1 convolution for skip connection
-                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
-                                                 out_channels=skip_channels,
-                                                 kernel_size=(1, 1)))
-                self.bn.append(nn.BatchNorm2d(residual_channels))
-                new_dilation *= 2
-                receptive_field += additional_scope
-                additional_scope *= 2
-                if self.gcn_bool:
-                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))
-
-        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
-                                    out_channels=end_channels,
-                                    kernel_size=(1, 1),
-                                    bias=True)
-
-        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
-                                    out_channels=out_dim,
-                                    kernel_size=(1, 1),
-                                    bias=True)
-
+                self.filter_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.gate_convs.append(nn.Conv2d(res, dil, (1, ks), dilation=new_dil))
+                self.residual_convs.append(nn.Conv2d(dil, res, 1))
+                self.skip_convs.append(nn.Conv2d(dil, skip, 1))
+                self.bn.append(nn.BatchNorm2d(res))
+                new_dil *= 2
+                receptive_field += add_scope
+                add_scope *= 2
+                if self.gcn_bool: self.gconv.append(gcn(dil, res, args['dropout'], support_len=self.supports_len))
+        self.end_conv_1 = nn.Conv2d(skip, endc, 1)
+        self.end_conv_2 = nn.Conv2d(endc, out_dim, 1)
         self.receptive_field = receptive_field
 
     def forward(self, input):
-        input = input[..., 0:2]
-        input = input.transpose(1,3)
-        input = nn.functional.pad(input,(1,0,0,0))
+        input = input[..., 0:2].transpose(1, 3)
+        input = F.pad(input, (1, 0, 0, 0))
         in_len = input.size(3)
-        if in_len < self.receptive_field:
-            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
-        else:
-            x = input
-        x = self.start_conv(x)
-        skip = 0
-
-        # calculate the current adaptive adj matrix once per iteration
-        new_supports = None
+        x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0)) if in_len < self.receptive_field else input
+        x, skip, new_supports = self.start_conv(x), 0, None
         if self.gcn_bool and self.addaptadj and self.supports is not None:
             adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
             new_supports = self.supports + [adp]
 
-        # WaveNet layers
         for i in range(self.blocks * self.layers):
-
-            #            |----------------------------------------|     *residual*
-            #            |                                        |
-            #            |    |-- conv -- tanh --|                |
-            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
-            #                 |-- conv -- sigm --|     |
-            #                                         1x1
-            #                                          |
-            # ---------------------------------------> + ------------->  *skip*
-
-            # (dilation, init_dilation) = self.dilations[i]
-
-            # residual = dilation_func(x, dilation, init_dilation, i)
             residual = x
-            # dilated convolution
-            filter = self.filter_convs[i](residual)
-            filter = torch.tanh(filter)
-            gate = self.gate_convs[i](residual)
-            gate = torch.sigmoid(gate)
-            x = filter * gate
-
-            # parametrized skip connection
-
-            s = x
-            s = self.skip_convs[i](s)
-            try:
-                skip = skip[:, :, :, -s.size(3):]
-            except:
-                skip = 0
-            skip = s + skip
-
+            f = self.filter_convs[i](residual).tanh()
+            g = self.gate_convs[i](residual).sigmoid()
+            x = f * g
+            s = self.skip_convs[i](x)
+            skip = (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + s
             if self.gcn_bool and self.supports is not None:
-                if self.addaptadj:
-                    x = self.gconv[i](x, new_supports)
-                else:
-                    x = self.gconv[i](x, self.supports)
+                x = self.gconv[i](x, new_supports if self.addaptadj else self.supports)
             else:
                 x = self.residual_convs[i](x)
 
             x = x + residual[:, :, :, -x.size(3):]
 
             x = self.bn[i](x)
-
-        x = F.relu(skip)
-        x = F.relu(self.end_conv_1(x))
-        x = self.end_conv_2(x)
-        return x
+        return self.end_conv_2(F.relu(self.end_conv_1(F.relu(skip))))
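A minimal smoke-test sketch for the simplified gwnet above; every config value below is an illustrative assumption (typical Graph WaveNet settings), not read from this repository's config files:

import torch

args = {'device': 'cpu', 'num_nodes': 207, 'in_dim': 2, 'out_dim': 12,
        'dropout': 0.3, 'blocks': 4, 'layers': 2, 'kernel_size': 2,
        'residual_channels': 32, 'dilation_channels': 32,
        'skip_channels': 256, 'end_channels': 512,
        'gcn_bool': True, 'addaptadj': True, 'supports': None}
model = gwnet(args)                   # gwnet as defined in the hunk above
x = torch.randn(8, 12, 207, 2)        # [batch, input_window, num_nodes, features]
y = model(x)
print(y.shape)                        # torch.Size([8, 12, 207, 1]) with these settings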
@@ -0,0 +1,215 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.autograd import Variable
+import sys
+
+
+class nconv(nn.Module):
+    def __init__(self):
+        super(nconv, self).__init__()
+
+    def forward(self, x, A):
+        x = torch.einsum('ncvl,vw->ncwl', (x, A))
+        return x.contiguous()
+
+
+class linear(nn.Module):
+    def __init__(self, c_in, c_out):
+        super(linear, self).__init__()
+        self.mlp = torch.nn.Conv2d(c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True)
+
+    def forward(self, x):
+        return self.mlp(x)
+
+
+class gcn(nn.Module):
+    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
+        super(gcn, self).__init__()
+        self.nconv = nconv()
+        c_in = (order * support_len + 1) * c_in
+        self.mlp = linear(c_in, c_out)
+        self.dropout = dropout
+        self.order = order
+
+    def forward(self, x, support):
+        out = [x]
+        for a in support:
+            x1 = self.nconv(x, a)
+            out.append(x1)
+            for k in range(2, self.order + 1):
+                x2 = self.nconv(x1, a)
+                out.append(x2)
+                x1 = x2
+
+        h = torch.cat(out, dim=1)
+        h = self.mlp(h)
+        h = F.dropout(h, self.dropout, training=self.training)
+        return h
+
+
+class gwnet(nn.Module):
+    def __init__(self, args):
+        super(gwnet, self).__init__()
+        self.dropout = args['dropout']
+        self.blocks = args['blocks']
+        self.layers = args['layers']
+        self.gcn_bool = args['gcn_bool']
+        self.addaptadj = args['addaptadj']
+
+        self.filter_convs = nn.ModuleList()
+        self.gate_convs = nn.ModuleList()
+        self.residual_convs = nn.ModuleList()
+        self.skip_convs = nn.ModuleList()
+        self.bn = nn.ModuleList()
+        self.gconv = nn.ModuleList()
+
+        self.start_conv = nn.Conv2d(in_channels=args['in_dim'],
+                                    out_channels=args['residual_channels'],
+                                    kernel_size=(1, 1))
+        self.supports = args.get('supports', None)
+
+        receptive_field = 1
+
+        self.supports_len = 0
+        if self.supports is not None:
+            self.supports_len += len(self.supports)
+
+        if self.gcn_bool and self.addaptadj:
+            aptinit = args.get('aptinit', None)
+            if aptinit is None:
+                if self.supports is None:
+                    self.supports = []
+                self.nodevec1 = nn.Parameter(torch.randn(args['num_nodes'], 10).to(args['device']),
+                                             requires_grad=True).to(args['device'])
+                self.nodevec2 = nn.Parameter(torch.randn(10, args['num_nodes']).to(args['device']),
+                                             requires_grad=True).to(args['device'])
+                self.supports_len += 1
+            else:
+                if self.supports is None:
+                    self.supports = []
+                m, p, n = torch.svd(aptinit)
+                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
+                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
+                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True).to(args['device'])
+                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True).to(args['device'])
+                self.supports_len += 1
+
+        kernel_size = args['kernel_size']
+        residual_channels = args['residual_channels']
+        dilation_channels = args['dilation_channels']
+        kernel_size = args['kernel_size']
+        skip_channels = args['skip_channels']
+        end_channels = args['end_channels']
+        out_dim = args['out_dim']
+        dropout = args['dropout']
+
+        for b in range(self.blocks):
+            additional_scope = kernel_size - 1
+            new_dilation = 1
+            for i in range(self.layers):
+                # dilated convolutions
+                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
+                                                   out_channels=dilation_channels,
+                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
+
+                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
+                                                 out_channels=dilation_channels,
+                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
+
+                # 1x1 convolution for residual connection
+                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
+                                                     out_channels=residual_channels,
+                                                     kernel_size=(1, 1)))
+
+                # 1x1 convolution for skip connection
+                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
+                                                 out_channels=skip_channels,
+                                                 kernel_size=(1, 1)))
+                self.bn.append(nn.BatchNorm2d(residual_channels))
+                new_dilation *= 2
+                receptive_field += additional_scope
+                additional_scope *= 2
+                if self.gcn_bool:
+                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout, support_len=self.supports_len))
+
+        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels,
+                                    out_channels=end_channels,
+                                    kernel_size=(1, 1),
+                                    bias=True)
+
+        self.end_conv_2 = nn.Conv2d(in_channels=end_channels,
+                                    out_channels=out_dim,
+                                    kernel_size=(1, 1),
+                                    bias=True)
+
+        self.receptive_field = receptive_field
+
+    def forward(self, input):
+        input = input[..., 0:2]
+        input = input.transpose(1,3)
+        input = nn.functional.pad(input,(1,0,0,0))
+        in_len = input.size(3)
+        if in_len < self.receptive_field:
+            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
+        else:
+            x = input
+        x = self.start_conv(x)
+        skip = 0
+
+        # calculate the current adaptive adj matrix once per iteration
+        new_supports = None
+        if self.gcn_bool and self.addaptadj and self.supports is not None:
+            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
+            new_supports = self.supports + [adp]
+
+        # WaveNet layers
+        for i in range(self.blocks * self.layers):
+
+            #            |----------------------------------------|     *residual*
+            #            |                                        |
+            #            |    |-- conv -- tanh --|                |
+            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
+            #                 |-- conv -- sigm --|     |
+            #                                         1x1
+            #                                          |
+            # ---------------------------------------> + ------------->  *skip*
+
+            # (dilation, init_dilation) = self.dilations[i]
+
+            # residual = dilation_func(x, dilation, init_dilation, i)
+            residual = x
+            # dilated convolution
+            filter = self.filter_convs[i](residual)
+            filter = torch.tanh(filter)
+            gate = self.gate_convs[i](residual)
+            gate = torch.sigmoid(gate)
+            x = filter * gate
+
+            # parametrized skip connection
+
+            s = x
+            s = self.skip_convs[i](s)
+            try:
+                skip = skip[:, :, :, -s.size(3):]
+            except:
+                skip = 0
+            skip = s + skip
+
+            if self.gcn_bool and self.supports is not None:
+                if self.addaptadj:
+                    x = self.gconv[i](x, new_supports)
+                else:
+                    x = self.gconv[i](x, self.supports)
+            else:
+                x = self.residual_convs[i](x)
+
+            x = x + residual[:, :, :, -x.size(3):]
+
+            x = self.bn[i](x)
+
+        x = F.relu(skip)
+        x = F.relu(self.end_conv_1(x))
+        x = self.end_conv_2(x)
+        return x
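Both the simplified model and the retained copy above build the adaptive adjacency as softmax(relu(E1 @ E2)) from two learned node embeddings; a shape sketch with an assumed node count for illustration:

import torch
import torch.nn.functional as F

num_nodes = 207                          # assumed node count
E1 = torch.randn(num_nodes, 10)          # plays the role of nodevec1
E2 = torch.randn(10, num_nodes)          # plays the role of nodevec2
adp = F.softmax(F.relu(torch.mm(E1, E2)), dim=1)
print(adp.shape, adp.sum(dim=1)[:3])     # [207, 207]; each row sums to 1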
@@ -1,147 +1,95 @@
 import torch
 import torch.nn as nn
-import math
-import numpy as np
+import math, numpy as np
 
 
 class HierAttnLstm(nn.Module):
     def __init__(self, args):
-        super(HierAttnLstm, self).__init__()
-        # self._scaler = self.data_feature.get('scaler')
-        self.num_nodes = args['num_nodes']
-        self.feature_dim = args['feature_dim']
-        self.output_dim = args['output_dim']
-
+        super().__init__()
+        self.num_nodes, self.feature_dim, self.output_dim = args['num_nodes'], args['feature_dim'], args['output_dim']
+        self.input_window, self.output_window = args['input_window'], args['output_window']
+        self.hidden_size, self.num_layers = args['hidden_size'], args['num_layers']
+        self.natt_hops, self.nfc, self.max_up_len = args['natt_hops'], args['nfc'], args['max_up_len']
+        self.input_size = self.num_nodes * self.feature_dim
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-        self.input_window = args['input_window']
-        self.output_window = args['output_window']
-
-        self.hidden_size = args['hidden_size']
-        self.num_layers = args['num_layers']
-        self.natt_unit = self.hidden_size
-        self.natt_hops = args['natt_hops']
-        self.nfc = args['nfc']
-        self.max_up_len = args['max_up_len']
-
-        self.input_size = self.num_nodes * self.feature_dim
-
-        self.lstm_cells = nn.ModuleList([
-            nn.LSTMCell(self.input_size, self.hidden_size)
-        ] + [
-            nn.LSTMCell(self.hidden_size, self.hidden_size) for _ in
-            range(self.num_layers - 1)
-        ])
-        self.hidden_state_pooling = nn.ModuleList([
-            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
-        ])
-        self.cell_state_pooling = nn.ModuleList([
-            SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)
-        ])
-        self.self_attention = SelfAttention(self.natt_unit, self.natt_hops)
+        self.lstm_cells = nn.ModuleList([nn.LSTMCell(self.input_size, self.hidden_size)] +
+                                        [nn.LSTMCell(self.hidden_size, self.hidden_size) for _ in
+                                         range(self.num_layers - 1)])
+        self.hidden_state_pooling = nn.ModuleList(
+            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)])
+        self.cell_state_pooling = nn.ModuleList(
+            [SelfAttentionPooling(self.hidden_size) for _ in range(self.num_layers - 1)])
+        self.self_attention = SelfAttention(self.hidden_size, self.natt_hops)
         self.fc_layer = nn.Sequential(
-            nn.Linear(self.hidden_size * self.natt_hops, self.nfc),
-            nn.ReLU(),
-            nn.Linear(self.nfc, self.num_nodes * self.output_dim)
-        )
+            nn.Linear(self.hidden_size * self.natt_hops, self.nfc), nn.ReLU(),
+            nn.Linear(self.nfc, self.num_nodes * self.output_dim))
 
     def forward(self, batch):
-        src = batch
-        # src = batch['X'].clone()  # [batch_size, input_window, num_nodes, feature_dim]
-        src = src.permute(1, 0, 2, 3)  # [input_window, batch_size, num_nodes, feature_dim]
-        # print("src shape: ", src.shape)
-        src = src[..., 0:1]
-        batch_size = src.shape[1]
-        src = src.reshape(self.input_window, batch_size, self.num_nodes * self.feature_dim)
+        src, batch_size = batch.permute(1, 0, 2, 3)[..., :1], batch.shape[0]
+        src = src.reshape(self.input_window, batch_size, -1)
 
         outputs = []
         for i in range(self.output_window):
-            hidden_states = [torch.zeros(batch_size, self.hidden_size).to(self.device) for _ in range(self.num_layers)]
-            cell_states = [torch.zeros(batch_size, self.hidden_size).to(self.device) for _ in range(self.num_layers)]
-
-            bottom_layer_outputs = []
-            cell_states_history = [[] for _ in range(self.num_layers)]
+            hidden_states, cell_states = [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in
+                                          range(self.num_layers)], \
+                [torch.zeros(batch_size, self.hidden_size, device=self.device) for _ in range(self.num_layers)]
+            bottom_layer_outputs, cell_states_history = [], [[] for _ in range(self.num_layers)]
             for t in range(self.input_window):
                 hidden_states[0], cell_states[0] = self.lstm_cells[0](src[t], (hidden_states[0], cell_states[0]))
                 bottom_layer_outputs.append(hidden_states[0])
                 cell_states_history[0].append(cell_states[0])
 
-            bottom_layer_outputs = torch.stack(bottom_layer_outputs, dim=1)
-            cell_states_history[0] = torch.stack(cell_states_history[0], dim=1)
+            bottom_layer_outputs, cell_states_history[0] = torch.stack(bottom_layer_outputs, 1), torch.stack(
+                cell_states_history[0], 1)
 
             for layer in range(1, self.num_layers):
                 layer_inputs = bottom_layer_outputs if layer == 1 else layer_outputs
-                layer_outputs = []
-                cell_states_history[layer] = []
-                layer_strides = self.calculate_stride(layer_inputs.size(1))
-
-                for start, end in layer_strides:
-                    segment = layer_inputs[:, start:end, :]
-                    cell_segment = cell_states_history[layer - 1][:, start:end, :]
-
-                    pooled_hidden = self.hidden_state_pooling[layer - 1](segment)
-                    pooled_cell = self.cell_state_pooling[layer - 1](
-                        torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], dim=1))
+                layer_outputs, cell_states_history[layer] = [], []
+                for start, end in self.calculate_stride(layer_inputs.size(1)):
+                    segment, cell_segment = layer_inputs[:, start:end, :], cell_states_history[layer - 1][:, start:end,
+                                                                           :]
+                    pooled_hidden, pooled_cell = self.hidden_state_pooling[layer - 1](segment), self.cell_state_pooling[
+                        layer - 1](torch.cat([cell_segment, cell_states[layer].unsqueeze(1)], 1))
                     hidden_states[layer], cell_states[layer] = self.lstm_cells[layer](pooled_hidden, (
                         hidden_states[layer], pooled_cell))
                     layer_outputs.append(hidden_states[layer])
                     cell_states_history[layer].append(cell_states[layer])
 
-                layer_outputs = torch.stack(layer_outputs, dim=1)
-                cell_states_history[layer] = torch.stack(cell_states_history[layer], dim=1)
+                layer_outputs, cell_states_history[layer] = torch.stack(layer_outputs, 1), torch.stack(
+                    cell_states_history[layer], 1)
 
-            # print("layer_outputs shape: ", layer_outputs.shape)  # [batch, sequence, hidden_size]
 
             attended_features, _ = self.self_attention(layer_outputs)
-            flattened = attended_features.view(batch_size, -1)
-            out = self.fc_layer(flattened)
-            out = out.view(batch_size, self.num_nodes, self.output_dim)
+            out = self.fc_layer(attended_features.view(batch_size, -1)).view(batch_size, self.num_nodes,
+                                                                             self.output_dim)
             outputs.append(out.clone())
 
             if i < self.output_window - 1:
-                src = torch.cat(
-                    (src[1:, :, :], out.reshape(batch_size, self.num_nodes * self.feature_dim).unsqueeze(0)), dim=0)
+                src = torch.cat((src[1:], out.reshape(batch_size, -1).unsqueeze(0)), 0)
 
-        outputs = torch.stack(outputs)
-        # outputs = [output_window, batch_size, num_nodes, output_dim]
-        return outputs.permute(1, 0, 2, 3)
+        return torch.stack(outputs).permute(1, 0, 2, 3)
 
-    def calculate_stride(self, sequence_len):
-        up_len = min(self.max_up_len, math.ceil(math.sqrt(sequence_len)))
-        idx = np.linspace(0, sequence_len - 1, num=up_len + 3).astype(int)
-        if idx[-1] != sequence_len - 1:
-            idx = np.append(idx, sequence_len - 1)
-        strides = list(zip(idx[:-1], idx[1:]))
-        return strides
+    def calculate_stride(self, seq_len):
+        idx = np.linspace(0, seq_len - 1, num=min(self.max_up_len, math.ceil(math.sqrt(seq_len))) + 3).astype(int)
+        return list(zip(np.append(idx, seq_len - 1)[:-1], idx[1:]))
 
 
 class SelfAttentionPooling(nn.Module):
     def __init__(self, input_dim):
-        super(SelfAttentionPooling, self).__init__()
+        super().__init__()
         self.W = nn.Linear(input_dim, 1)
 
     def forward(self, batch_rep):
-        softmax = nn.functional.softmax
-        att_w = softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
-        utter_rep = torch.sum(batch_rep * att_w, dim=1)
-        return utter_rep
+        att_w = nn.functional.softmax(self.W(batch_rep).squeeze(-1), dim=-1).unsqueeze(-1)
+        return torch.sum(batch_rep * att_w, dim=1)
 
 
 class SelfAttention(nn.Module):
-    def __init__(self, attention_size, att_hops):
-        super(SelfAttention, self).__init__()
-        self.ut_dense = nn.Sequential(
-            nn.Linear(attention_size, attention_size),
-            nn.Tanh()
-        )
-        self.et_dense = nn.Linear(attention_size, att_hops)
-        self.softmax = nn.Softmax(dim=-1)
+    def __init__(self, att_size, att_hops):
+        super().__init__()
+        self.ut_dense = nn.Sequential(nn.Linear(att_size, att_size), nn.Tanh())
+        self.et_dense, self.softmax = nn.Linear(att_size, att_hops), nn.Softmax(dim=-1)
 
     def forward(self, inputs):
-        # inputs is a 3D Tensor: batch, len, hidden_size
-        # scores is a 2D Tensor: batch, len
-        ut = self.ut_dense(inputs)
-        # et shape: [batch_size, seq_len, att_hops]
-        et = self.et_dense(ut)
-        att_scores = self.softmax(torch.permute(et, (0, 2, 1)))
-        output = torch.bmm(att_scores, inputs)
-        return output, att_scores
+        att_scores = self.softmax(self.et_dense(self.ut_dense(inputs)).permute(0, 2, 1))
+        return torch.bmm(att_scores, inputs), att_scores
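As a sanity check on the condensed calculate_stride, a standalone sketch with max_up_len as a free parameter (the default below is an assumption). Since np.linspace already ends at seq_len - 1, the appended endpoint in the new one-liner is dropped again by the [:-1] slice, so the (start, end) pairs match the original version:

import math
import numpy as np

def calculate_stride(seq_len, max_up_len=80):
    up_len = min(max_up_len, math.ceil(math.sqrt(seq_len)))
    idx = np.linspace(0, seq_len - 1, num=up_len + 3).astype(int)
    return list(zip(idx[:-1], idx[1:]))

print(calculate_stride(12))   # start/end pairs: (0, 1), (1, 3), (3, 5), (5, 7), (7, 9), (9, 11)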
@@ -14,6 +14,7 @@ from model.STSGCN.STSGCN import STSGCN
 from model.STGODE.STGODE import ODEGCN
 from model.PDG2SEQ.PDG2Seq import PDG2Seq
 from model.EXP.EXP import EXP
+from model.EXPB.EXP_b import EXPB
 
 def model_selector(model):
     match model['type']:
@@ -33,4 +34,5 @@ def model_selector(model):
         case 'STGODE': return ODEGCN(model)
         case 'PDG2SEQ': return PDG2Seq(model)
         case 'EXP': return EXP(model)
+        case 'EXPB': return EXPB(model)
 
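A hedged sketch of how the new branch is reached; apart from the 'type' key used by the match statement above, the EXPB hyperparameters are omitted because they are not shown in this diff:

cfg = {'type': 'EXPB'}        # plus whatever hyperparameters EXPB expects
model = model_selector(cfg)   # dispatches to EXPB(cfg) via the new case arm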