TrafficWheel/model/GWN/GraphWaveNet_exp.py


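"""Graph WaveNet model used by TrafficWheel's GWN experiments.

Defines the diffusion graph convolution (``nconv``/``gcn``) and the gated,
dilated temporal convolution stack (``gwnet``) with an optional adaptive
adjacency matrix, following the Graph WaveNet architecture.
"""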
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import sys
class nconv(nn.Module):
    def __init__(self):
        super(nconv, self).__init__()

    def forward(self, x, A):
        # Graph diffusion step: x is (batch, channel, node, time), A is (node, node).
        x = torch.einsum("ncvl,vw->ncwl", (x, A))
        return x.contiguous()
class linear(nn.Module):
    def __init__(self, c_in, c_out):
        super(linear, self).__init__()
        # 1x1 convolution acting as a per-node, per-step linear layer.
        self.mlp = torch.nn.Conv2d(
            c_in, c_out, kernel_size=(1, 1), padding=(0, 0), stride=(1, 1), bias=True
        )

    def forward(self, x):
        return self.mlp(x)
class gcn(nn.Module):
    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super(gcn, self).__init__()
        self.nconv = nconv()
        # Concatenate the input with `order` diffusion steps over each support matrix.
        c_in = (order * support_len + 1) * c_in
        self.mlp = linear(c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        out = [x]
        for a in support:
            x1 = self.nconv(x, a)
            out.append(x1)
            for k in range(2, self.order + 1):
                x2 = self.nconv(x1, a)
                out.append(x2)
                x1 = x2
        h = torch.cat(out, dim=1)
        h = self.mlp(h)
        h = F.dropout(h, self.dropout, training=self.training)
        return h
class gwnet(nn.Module):
    def __init__(self, args):
        super(gwnet, self).__init__()
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.gcn_bool = args["gcn_bool"]
        self.addaptadj = args["addaptadj"]

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()

        self.start_conv = nn.Conv2d(
            in_channels=args["in_dim"],
            out_channels=args["residual_channels"],
            kernel_size=(1, 1),
        )
        self.supports = args.get("supports", None)

        receptive_field = 1

        self.supports_len = 0
        if self.supports is not None:
            self.supports_len += len(self.supports)

        if self.gcn_bool and self.addaptadj:
            aptinit = args.get("aptinit", None)
            if aptinit is None:
                if self.supports is None:
                    self.supports = []
                # Learnable node embeddings for the adaptive adjacency matrix.
                self.nodevec1 = nn.Parameter(
                    torch.randn(args["num_nodes"], 10).to(args["device"]),
                    requires_grad=True,
                )
                self.nodevec2 = nn.Parameter(
                    torch.randn(10, args["num_nodes"]).to(args["device"]),
                    requires_grad=True,
                )
                self.supports_len += 1
            else:
                if self.supports is None:
                    self.supports = []
                # Initialise the node embeddings from a truncated SVD of aptinit.
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(
                    initemb1.to(args["device"]), requires_grad=True
                )
                self.nodevec2 = nn.Parameter(
                    initemb2.to(args["device"]), requires_grad=True
                )
                self.supports_len += 1

        kernel_size = args["kernel_size"]
        residual_channels = args["residual_channels"]
        dilation_channels = args["dilation_channels"]
        skip_channels = args["skip_channels"]
        end_channels = args["end_channels"]
        out_dim = args["out_dim"]
        dropout = args["dropout"]
        for b in range(self.blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for i in range(self.layers):
                # dilated convolutions
                self.filter_convs.append(
                    nn.Conv2d(
                        in_channels=residual_channels,
                        out_channels=dilation_channels,
                        kernel_size=(1, kernel_size),
                        dilation=new_dilation,
                    )
                )
                self.gate_convs.append(
                    nn.Conv2d(
                        in_channels=residual_channels,
                        out_channels=dilation_channels,
                        kernel_size=(1, kernel_size),
                        dilation=new_dilation,
                    )
                )
                # 1x1 convolution for residual connection
                self.residual_convs.append(
                    nn.Conv2d(
                        in_channels=dilation_channels,
                        out_channels=residual_channels,
                        kernel_size=(1, 1),
                    )
                )
                # 1x1 convolution for skip connection
                self.skip_convs.append(
                    nn.Conv2d(
                        in_channels=dilation_channels,
                        out_channels=skip_channels,
                        kernel_size=(1, 1),
                    )
                )
                self.bn.append(nn.BatchNorm2d(residual_channels))
                # Double the dilation each layer and grow the receptive field accordingly.
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(
                        gcn(
                            dilation_channels,
                            residual_channels,
                            dropout,
                            support_len=self.supports_len,
                        )
                    )

        self.end_conv_1 = nn.Conv2d(
            in_channels=skip_channels,
            out_channels=end_channels,
            kernel_size=(1, 1),
            bias=True,
        )
        self.end_conv_2 = nn.Conv2d(
            in_channels=end_channels,
            out_channels=out_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        self.receptive_field = receptive_field
    def forward(self, input):
        # Keep only the first two input features and move them to the channel dim:
        # (batch, time, node, feature) -> (batch, feature, node, time).
        input = input[..., 0:2]
        input = input.transpose(1, 3)
        input = nn.functional.pad(input, (1, 0, 0, 0))
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = nn.functional.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)
        skip = 0

        # calculate the current adaptive adj matrix once per iteration
        new_supports = None
        if self.gcn_bool and self.addaptadj and self.supports is not None:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            new_supports = self.supports + [adp]

        # WaveNet layers
        for i in range(self.blocks * self.layers):
            #            |----------------------------------------|     *residual*
            #            |                                         |
            #            |    |-- conv -- tanh --|                 |
            # -> dilate -|----|                  * ----|-- 1x1 -- + -->  *input*
            #                 |-- conv -- sigm --|     |
            #                                         1x1
            #                                          |
            # ---------------------------------------> + ------------->  *skip*
            residual = x
            # dilated convolution
            filter = self.filter_convs[i](residual)
            filter = torch.tanh(filter)
            gate = self.gate_convs[i](residual)
            gate = torch.sigmoid(gate)
            x = filter * gate

            # parametrized skip connection
            s = x
            s = self.skip_convs[i](s)
            if torch.is_tensor(skip):
                # Align the accumulated skip tensor with the (shorter) current output.
                skip = skip[:, :, :, -s.size(3):]
            skip = s + skip

            if self.gcn_bool and self.supports is not None:
                if self.addaptadj:
                    x = self.gconv[i](x, new_supports)
                else:
                    x = self.gconv[i](x, self.supports)
            else:
                x = self.residual_convs[i](x)

            # Residual connection: keep only the last x.size(3) time steps of the input.
            y = residual[:, :, :, -x.size(3):]
            x = x + y
            x = self.bn[i](x)

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
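

if __name__ == "__main__":
    # Minimal smoke test, not part of the original training pipeline: the config
    # below is an assumed example (METR-LA-like sizes); real hyperparameters come
    # from the TrafficWheel configuration files.
    device = torch.device("cpu")
    args = {
        "dropout": 0.3,
        "blocks": 4,
        "layers": 2,
        "gcn_bool": True,
        "addaptadj": True,
        "in_dim": 2,
        "out_dim": 12,
        "residual_channels": 32,
        "dilation_channels": 32,
        "skip_channels": 256,
        "end_channels": 512,
        "kernel_size": 2,
        "num_nodes": 207,
        "device": device,
        "supports": [torch.eye(207)],  # placeholder adjacency; normally a graph-derived transition matrix
    }
    model = gwnet(args)
    # Input is (batch, time, node, feature); only the first two features are used.
    x = torch.randn(8, 12, 207, 2)
    y = model(x)
    print(y.shape)  # expected: torch.Size([8, 12, 207, 1])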