import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BatchNorm2d, Conv2d, ModuleList, Parameter


def nconv(x, A):
    """Multiply x by adjacency matrix along the source-node axis.

    x: (batch, channels, num_nodes, time), A: (num_nodes, num_nodes);
    the result has the same shape as x when A is square.
    """
    return torch.einsum('ncvl,vw->ncwl', (x, A)).contiguous()


class GraphConvNet(nn.Module):
    """Diffusion graph convolution: propagates features over each support
    matrix for up to `order` hops, concatenates the results along the channel
    axis and mixes them with a 1x1 convolution."""

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        c_in = (order * support_len + 1) * c_in  # original features + one block per support and hop
        self.final_conv = Conv2d(c_in, c_out, (1, 1), padding=(0, 0), stride=(1, 1), bias=True)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support: list):
        out = [x]
        for a in support:
            x1 = nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                # higher-order diffusion: propagate the previous result one more hop
                x2 = nconv(x1, a)
                out.append(x2)
                x1 = x2

        h = torch.cat(out, dim=1)
        h = self.final_conv(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h
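

# Shape conventions for GraphConvNet.forward (derived from the einsum in nconv;
# the names below are descriptive only):
#   x        -> (batch, c_in, num_nodes, n_timesteps)
#   support  -> list of (num_nodes, num_nodes) adjacency matrices, len == support_len
#   returns  -> (batch, c_out, num_nodes, n_timesteps)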


class gwnet(nn.Module):
    """Main Graph WaveNet model.

    Combines graph convolution with dilated temporal (WaveNet-style) convolution
    for spatio-temporal forecasting.
    """

    def __init__(self, args):
        super().__init__()
        # basic hyper-parameters
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.do_graph_conv = args.get("do_graph_conv", True)
        self.cat_feat_gc = args.get("cat_feat_gc", False)
        self.addaptadj = args.get("addaptadj", True)
        supports = None  # no pre-computed adjacency matrices in this version
        aptinit = args.get("aptinit", None)
        in_dim = args.get("in_dim")
        out_dim = args.get("out_dim")
        residual_channels = args.get("residual_channels")
        dilation_channels = args.get("dilation_channels")
        skip_channels = args.get("skip_channels")
        end_channels = args.get("end_channels")
        kernel_size = args.get("kernel_size")
        apt_size = args.get("apt_size", 10)

        if self.cat_feat_gc:
            self.start_conv = nn.Conv2d(in_channels=1,  # hard code to avoid errors
                                        out_channels=residual_channels,
                                        kernel_size=(1, 1))
            self.cat_feature_conv = nn.Conv2d(in_channels=in_dim - 1,
                                              out_channels=residual_channels,
                                              kernel_size=(1, 1))
        else:
            self.start_conv = nn.Conv2d(in_channels=in_dim,
                                        out_channels=residual_channels,
                                        kernel_size=(1, 1))

        self.fixed_supports = supports or []
        receptive_field = 1

        self.supports_len = len(self.fixed_supports)
        if self.do_graph_conv and self.addaptadj:
            # learnable node embeddings that parameterize the adaptive adjacency matrix
            if aptinit is None:
                nodevecs = torch.randn(args["num_nodes"], apt_size), torch.randn(apt_size, args["num_nodes"])
            else:
                nodevecs = self.svd_init(args["num_nodes"], apt_size, aptinit)
            self.supports_len += 1
            self.nodevec1, self.nodevec2 = [Parameter(n.to(args["device"]), requires_grad=True) for n in nodevecs]

        depth = list(range(self.blocks * self.layers))

        # 1x1 convolutions for residual and skip connections
        self.residual_convs = ModuleList([Conv2d(dilation_channels, residual_channels, (1, 1)) for _ in depth])
        self.skip_convs = ModuleList([Conv2d(dilation_channels, skip_channels, (1, 1)) for _ in depth])
        self.bn = ModuleList([BatchNorm2d(residual_channels) for _ in depth])
        self.graph_convs = ModuleList([GraphConvNet(dilation_channels, residual_channels, self.dropout,
                                                    support_len=self.supports_len)
                                       for _ in depth])

        self.filter_convs = ModuleList()
        self.gate_convs = ModuleList()
        for b in range(self.blocks):
            additional_scope = kernel_size - 1
            D = 1  # dilation
            for i in range(self.layers):
                # dilated convolutions; dilation doubles at every layer within a block
                self.filter_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
                self.gate_convs.append(Conv2d(residual_channels, dilation_channels, (1, kernel_size), dilation=D))
                D *= 2
                receptive_field += additional_scope
                additional_scope *= 2
        self.receptive_field = receptive_field

        self.end_conv_1 = Conv2d(skip_channels, end_channels, (1, 1), bias=True)
        self.end_conv_2 = Conv2d(end_channels, out_dim, (1, 1), bias=True)
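
    # `svd_init` is called in __init__ when `aptinit` is provided, but it is not
    # defined in this file. The sketch below is an assumption of the intended
    # behaviour: factor the given adjacency prior with a truncated SVD so that
    # nodevec1 @ nodevec2 approximates `aptinit`.
    @staticmethod
    def svd_init(num_nodes, apt_size, aptinit):
        # num_nodes is accepted only to match the call signature above
        m, p, n = torch.svd(aptinit)
        nodevec1 = torch.mm(m[:, :apt_size], torch.diag(p[:apt_size] ** 0.5))
        nodevec2 = torch.mm(torch.diag(p[:apt_size] ** 0.5), n[:, :apt_size].t())
        return nodevec1, nodevec2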

    def forward(self, input):
        # input: (batch, n_timesteps, n_nodes, n_features); keep only the first
        # feature and move it to the channel axis
        x = input[..., 0:1].transpose(1, 3)
        # x shape is now (bs, features, n_nodes, n_timesteps)
        in_len = x.size(3)
        if in_len < self.receptive_field:
            # left-pad the time axis so the dilated convolutions see enough history
            x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
        if self.cat_feat_gc:
            f1, f2 = x[:, [0]], x[:, 1:]
            x1 = self.start_conv(f1)
            x2 = F.leaky_relu(self.cat_feature_conv(f2))
            x = x1 + x2
        else:
            x = self.start_conv(x)
        skip = 0
        adjacency_matrices = self.fixed_supports
        # calculate the current adaptive adjacency matrix once per forward pass
        if self.do_graph_conv and self.addaptadj:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            adjacency_matrices = self.fixed_supports + [adp]

        # WaveNet layers
        for i in range(self.blocks * self.layers):
            residual = x
            # dilated convolution with a gated activation unit
            filter = torch.tanh(self.filter_convs[i](residual))
            gate = torch.sigmoid(self.gate_convs[i](residual))
            x = filter * gate
            # parametrized skip connection
            s = self.skip_convs[i](x)
            if torch.is_tensor(skip):  # after the first layer, crop skip to the current time length
                skip = skip[:, :, :, -s.size(3):]  # TODO(SS): Mean/Max Pool?
            skip = s + skip
            if i == (self.blocks * self.layers - 1):  # last X getting ignored anyway
                break

            if self.do_graph_conv:
                graph_out = self.graph_convs[i](x, adjacency_matrices)
                x = (x + graph_out) if self.cat_feat_gc else graph_out
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]  # TODO(SS): Mean/Max Pool?
            x = self.bn[i](x)

        x = F.relu(skip)  # ignore last X?
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # project to (bs, out_dim, n_nodes, remaining timesteps)
        # x = x.transpose(1, 3)
        return x
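

# Minimal smoke-test sketch (not part of the original file). The hyper-parameter
# values below are illustrative assumptions, chosen only to exercise the model
# end-to-end on random data; in_dim is 1 because forward() keeps only the first
# input feature.
if __name__ == "__main__":
    args = {
        "dropout": 0.3,
        "blocks": 4,
        "layers": 2,
        "in_dim": 1,
        "out_dim": 12,            # prediction horizon
        "residual_channels": 32,
        "dilation_channels": 32,
        "skip_channels": 256,
        "end_channels": 512,
        "kernel_size": 2,
        "num_nodes": 207,         # e.g. METR-LA
        "device": "cpu",
    }
    model = gwnet(args)
    # dummy input: (batch, n_timesteps, n_nodes, n_features)
    dummy = torch.randn(8, 12, args["num_nodes"], args["in_dim"])
    out = model(dummy)
    print(out.shape)  # (8, 12, 207, 1): (batch, out_dim, num_nodes, remaining timesteps)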