TrafficWheel/model/STEP/graphwavenet.py

225 lines
8.8 KiB
Python

import torch
from torch import nn
import torch.nn.functional as F
class nconv(nn.Module):
    """Propagate node features of ``x`` through an adjacency ``A``.

    With ``x`` of shape [n, c, v, l] the result is
    ``out[n, c, w, l] = sum_v x[n, c, v, l] * A[..., v, w]``.
    ``A`` is either a single [v, w] matrix shared by the whole batch or a
    batched [n, v, w] tensor (selected by ``A``'s rank).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, A):
        # Move the adjacency next to the features before contracting.
        A = A.to(x.device)
        equation = 'ncvl,nvw->ncwl' if A.dim() == 3 else 'ncvl,vw->ncwl'
        return torch.einsum(equation, x, A).contiguous()
class linear(nn.Module):
    """Pointwise (1x1) convolution mapping ``c_in`` channels to ``c_out`` channels, with bias."""

    def __init__(self, c_in, c_out):
        super().__init__()
        self.mlp = torch.nn.Conv2d(c_in, c_out, (1, 1), stride=(1, 1), padding=(0, 0), bias=True)

    def forward(self, x):
        out = self.mlp(x)
        return out
class gcn(nn.Module):
    """Mix-hop graph convolution.

    For every adjacency in ``support`` the input is propagated repeatedly with
    ``nconv`` (``order`` hops for ``order >= 1``). The input and all propagated
    tensors are concatenated on the channel axis, projected to ``c_out``
    channels with a 1x1 convolution, and passed through dropout.

    Args:
        c_in: channels of the input ``x``.
        c_out: channels of the output.
        dropout: dropout probability applied to the output while training.
        support_len: number of adjacencies ``forward`` will receive; it fixes
            the input width of the projection.
        order: number of propagation hops per adjacency.
    """

    def __init__(self, c_in, c_out, dropout, support_len=3, order=2):
        super().__init__()
        self.nconv = nconv()
        # Input + `order` hops for each of the `support_len` adjacencies.
        self.mlp = linear((order * support_len + 1) * c_in, c_out)
        self.dropout = dropout
        self.order = order

    def forward(self, x, support):
        hops = [x]
        for adj in support:
            h = self.nconv(x, adj)
            hops.append(h)
            # Remaining hops re-propagate the previous hop's output.
            for _ in range(1, self.order):
                h = self.nconv(h, adj)
                hops.append(h)
        mixed = self.mlp(torch.cat(hops, dim=1))
        return F.dropout(mixed, self.dropout, training=self.training)
class GraphWaveNet(nn.Module):
    """
    Paper: Graph WaveNet for Deep Spatial-Temporal Graph Modeling.
    Link: https://arxiv.org/abs/1906.00121
    Ref Official Code: https://github.com/nnzhan/Graph-WaveNet/blob/master/model.py

    In this variant the prior supports are rebuilt on every forward pass from
    a per-sample adjacency (``sampled_adj``), and a projection of externally
    supplied ``hidden_states`` is added to the skip connection before the
    output head.
    """

    def __init__(self, num_nodes, support_len, dropout=0.3, gcn_bool=True, addaptadj=True,
                 aptinit=None, in_dim=2, out_dim=12, residual_channels=32, dilation_channels=32,
                 skip_channels=256, end_channels=512, kernel_size=2, blocks=4, layers=2,
                 his_dim=96, **kwargs):
        """Build the model.

        Args:
            num_nodes (int): number of graph nodes; sizes the randomly
                initialised adaptive node embeddings.
            support_len (int): number of prior supports given to each graph
                convolution. ``forward`` always builds two (random-walk
                matrices of ``sampled_adj`` and of its transpose), so this
                should presumably be 2 -- confirm against the config.
            dropout (float): dropout rate inside each graph convolution.
            gcn_bool (bool): apply a graph convolution after each gated layer;
                otherwise a 1x1 residual convolution is used.
            addaptadj (bool): together with ``gcn_bool``, append a learned
                adaptive adjacency softmax(relu(E1 @ E2)) as one extra support.
            aptinit (torch.Tensor, optional): matrix (presumably [N, N])
                whose rank-10 SVD initialises the node embeddings; random
                initialisation when None.
            in_dim (int): number of leading input features fed to the model.
            out_dim (int): output channels; the last dimension of the output.
            residual_channels, dilation_channels, skip_channels, end_channels
                (int): channel widths of the respective convolutions.
            kernel_size (int): temporal kernel size of the dilated convolutions.
            blocks, layers (int): number of blocks and of dilated layers per
                block; the dilation doubles within each block.
            his_dim (int): size of the last dimension of ``hidden_states``.
            **kwargs: accepted and ignored.
        """
        super().__init__()
        self.dropout = dropout
        self.blocks = blocks
        self.layers = layers
        self.gcn_bool = gcn_bool
        self.addaptadj = addaptadj
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.bn = nn.ModuleList()
        self.gconv = nn.ModuleList()
        # Projects hidden_states to skip_channels so it can be added to the skip sum.
        self.fc_his = nn.Sequential(nn.Linear(his_dim, 512), nn.ReLU(),
                                    nn.Linear(512, skip_channels), nn.ReLU())
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=residual_channels,
                                    kernel_size=(1, 1))
        receptive_field = 1
        self.supports_len = support_len
        if gcn_bool and addaptadj:
            if aptinit is None:
                self.nodevec1 = nn.Parameter(torch.randn(num_nodes, 10), requires_grad=True)
                self.nodevec2 = nn.Parameter(torch.randn(10, num_nodes), requires_grad=True)
            else:
                m, p, n = torch.svd(aptinit)
                initemb1 = torch.mm(m[:, :10], torch.diag(p[:10] ** 0.5))
                initemb2 = torch.mm(torch.diag(p[:10] ** 0.5), n[:, :10].t())
                self.nodevec1 = nn.Parameter(initemb1, requires_grad=True)
                self.nodevec2 = nn.Parameter(initemb2, requires_grad=True)
            # The adaptive adjacency is appended as one extra support in forward.
            self.supports_len += 1
        for _ in range(blocks):
            additional_scope = kernel_size - 1
            new_dilation = 1
            for _ in range(layers):
                # dilated convolutions
                self.filter_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                   out_channels=dilation_channels,
                                                   kernel_size=(1, kernel_size),
                                                   dilation=new_dilation))
                self.gate_convs.append(nn.Conv2d(in_channels=residual_channels,
                                                 out_channels=dilation_channels,
                                                 kernel_size=(1, kernel_size),
                                                 dilation=new_dilation))
                # 1x1 convolution for residual connection (used when gcn_bool is False)
                self.residual_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                     out_channels=residual_channels,
                                                     kernel_size=(1, 1)))
                # 1x1 convolution for skip connection
                self.skip_convs.append(nn.Conv2d(in_channels=dilation_channels,
                                                 out_channels=skip_channels,
                                                 kernel_size=(1, 1)))
                self.bn.append(nn.BatchNorm2d(residual_channels))
                new_dilation *= 2
                receptive_field += additional_scope
                additional_scope *= 2
                if self.gcn_bool:
                    self.gconv.append(gcn(dilation_channels, residual_channels, dropout,
                                          support_len=self.supports_len))
        self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, out_channels=end_channels,
                                    kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=end_channels, out_channels=out_dim,
                                    kernel_size=(1, 1), bias=True)
        self.receptive_field = receptive_field

    def _calculate_random_walk_matrix(self, adj_mx):
        """Return the row-normalised random-walk matrix D^-1 (A + I).

        Args:
            adj_mx (torch.Tensor): batched adjacency with shape [B, N, N].

        Returns:
            torch.Tensor: [B, N, N] tensor where row i of A + I is divided by
            that row's sum; rows whose sum is zero become all zeros.
        """
        num_nodes = adj_mx.shape[-1]
        # Default-dtype identity, same promotion as before, created directly on the device.
        adj_mx = adj_mx + torch.eye(num_nodes, device=adj_mx.device)
        d_inv = 1. / adj_mx.sum(dim=-1)
        d_inv = torch.where(torch.isinf(d_inv), torch.zeros_like(d_inv), d_inv)
        # Row scaling == diag(d_inv) @ adj, without materialising the dense diagonal.
        return d_inv.unsqueeze(-1) * adj_mx

    def forward(self, input, hidden_states, sampled_adj):
        """Feed forward of Graph WaveNet.

        Args:
            input (torch.Tensor): input history MTS with shape [B, L, N, C];
                only the first ``in_dim`` features are used.
            hidden_states (torch.Tensor): per-node representation with shape
                [B, N, his_dim] (in STEP, the output of TSFormer for the last
                patch/segment).
            sampled_adj (torch.Tensor): dependency graph with shape [B, N, N].

        Returns:
            torch.Tensor: prediction with shape [B, N, out_dim] when the
            padded input length does not exceed the receptive field.

        Note:
            ``self.supports`` is overwritten on every call with the two
            random-walk matrices derived from ``sampled_adj``.
        """
        # reshape input: [B, L, N, C] -> [B, C, N, L]
        input = input.transpose(1, 3)
        # Prepend one zero step on the time axis.
        input = F.pad(input, (1, 0, 0, 0))
        # Keep only as many features as start_conv consumes (in_dim).
        input = input[:, :self.start_conv.in_channels, :, :]
        in_len = input.size(3)
        if in_len < self.receptive_field:
            x = F.pad(input, (self.receptive_field - in_len, 0, 0, 0))
        else:
            x = input
        x = self.start_conv(x)

        # Prior supports: forward and backward random walks on the sampled graph.
        self.supports = [self._calculate_random_walk_matrix(sampled_adj),
                         self._calculate_random_walk_matrix(sampled_adj.transpose(-1, -2))]
        supports = self.supports
        if self.gcn_bool and self.addaptadj:
            adp = F.softmax(F.relu(torch.mm(self.nodevec1, self.nodevec2)), dim=1)
            supports = self.supports + [adp]

        skip = 0
        for i in range(self.blocks * self.layers):
            residual = x
            # Gated dilated temporal convolution: tanh(filter) * sigmoid(gate).
            filter_out = torch.tanh(self.filter_convs[i](residual))
            gate_out = torch.sigmoid(self.gate_convs[i](residual))
            x = filter_out * gate_out
            # Parametrised skip connection; earlier skips are cropped to the
            # current (shrinking) time length before summing.
            s = self.skip_convs[i](x)
            if torch.is_tensor(skip):
                skip = skip[:, :, :, -s.size(3):]
            skip = s + skip
            if self.gcn_bool:
                x = self.gconv[i](x, supports)
            else:
                x = self.residual_convs[i](x)
            x = x + residual[:, :, :, -x.size(3):]
            x = self.bn[i](x)

        # [B, N, his_dim] -> [B, skip_channels, N, 1], broadcast over time.
        hidden_states = self.fc_his(hidden_states).transpose(1, 2).unsqueeze(-1)
        skip = skip + hidden_states
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        # reshape output: [B, out_dim, N, 1] -> [B, N, out_dim]
        return x.squeeze(-1).transpose(1, 2)