# TrafficWheel/model/STGODE/STGODE.py

import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from model.STGODE.odegcn import ODEG
from model.STGODE.adj import get_A_hat

class Chomp1d(nn.Module):
    """
    remove the extra timesteps on the time axis that were added by padding
    """
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :, :-self.chomp_size].contiguous()
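
# Illustrative example (not in the original file): Chomp1d trims the trailing timesteps
# that the padded convolutions in TemporalConvNet below add, e.g. with chomp_size=2
#     Chomp1d(2)(torch.randn(4, 64, 10, 14)).shape == torch.Size([4, 64, 10, 12])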

class TemporalConvNet(nn.Module):
    """
    dilated temporal convolution along the time axis
    """
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Args:
            num_inputs : number of channels of the input features
            num_channels : list of output channel sizes for each temporal layer; the last entry is the final output channel count
            kernel_size : temporal kernel size; the convolution is a 2d conv with kernel (1, kernel_size), so it acts only along the time axis
        """
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            # symmetric padding on the time axis; Chomp1d trims the trailing part so the output length matches the input
            padding = (kernel_size - 1) * dilation_size
            conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size),
                             dilation=(1, dilation_size), padding=(0, padding))
            conv.weight.data.normal_(0, 0.01)
            layers += [nn.Sequential(conv, Chomp1d(padding), nn.ReLU(), nn.Dropout(dropout))]
        self.network = nn.Sequential(*layers)
        # 1x1 conv to match channels for the residual connection when needed
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)
    def forward(self, x):
        """
        residual temporal convolution, like ResNet
        Args:
            x : input data of shape (B, N, T, F)
        Returns:
            output of shape (B, N, T, num_channels[-1])
        """
        # permute shape to (B, F, N, T) so the convolution runs along the time axis
        y = x.permute(0, 3, 1, 2)
        # residual connection; the parentheses matter, otherwise the ternary would
        # bypass self.network entirely whenever no downsample conv is needed
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
        y = y.permute(0, 2, 3, 1)
        return y
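
# Illustrative shape check (not part of the original model; the sizes below are
# arbitrary assumptions): TemporalConvNet keeps (B, N, T) and maps the feature
# dimension to num_channels[-1], e.g.
#     tcn = TemporalConvNet(num_inputs=2, num_channels=[64, 32, 64])
#     tcn(torch.randn(8, 207, 12, 2)).shape == torch.Size([8, 207, 12, 64])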

class GCN(nn.Module):
    """ a single graph convolution layer: relu(A_hat @ X @ theta) """
    def __init__(self, A_hat, in_channels, out_channels):
        super(GCN, self).__init__()
        self.A_hat = A_hat
        self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
        self.reset()

    def reset(self):
        stdv = 1. / math.sqrt(self.theta.shape[1])
        self.theta.data.uniform_(-stdv, stdv)

    def forward(self, X):
        # aggregate over the node dimension with the normalized adjacency matrix
        y = torch.einsum('ij, kjlm-> kilm', self.A_hat, X)
        # then transform the feature dimension with theta
        return F.relu(torch.einsum('kjlm, mn->kjln', y, self.theta))
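
# Illustrative use (not in the original file; all sizes are assumptions): with a
# normalized adjacency A_hat of shape (N, N), GCN maps (B, N, T, F_in) to (B, N, T, F_out), e.g.
#     gcn = GCN(torch.eye(5), in_channels=2, out_channels=16)
#     gcn(torch.randn(4, 5, 12, 2)).shape == torch.Size([4, 5, 12, 16])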

class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        """
        Args:
            in_channels: number of input features at each node in each time step
            out_channels: list of feature channel sizes for the temporal conv nets; the last entry is the block's output channel count
            num_nodes: number of nodes in the graph
            A_hat: the normalized adjacency matrix
        """
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels,
                                         num_channels=out_channels)
        # graph ODE layer; note the temporal dimension is hard-coded to 12 here
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1],
                                         num_channels=out_channels)
        # BatchNorm2d over the node dimension of (B, N, T, C)
        self.batch_norm = nn.BatchNorm2d(num_nodes)
def forward(self, X):
"""
Args:
X: Input data of shape (batch_size, num_nodes, num_timesteps, num_features)
Return:
            Output data of shape (batch_size, num_nodes, num_timesteps, out_channels[-1])
"""
t = self.temporal1(X)
t = self.odeg(t)
t = self.temporal2(F.relu(t))
return self.batch_norm(t)
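
# Illustrative shapes (added here, not in the original file; sizes are assumptions and
# this presumes ODEG accepts any (N, N) float tensor as the adjacency): one STGCNBlock
# maps (B, N, T, in_channels) -> (B, N, T, out_channels[-1]), with T expected to be 12, e.g.
#     blk = STGCNBlock(in_channels=2, out_channels=[64, 32, 64], num_nodes=207, A_hat=torch.eye(207))
#     blk(torch.randn(8, 207, 12, 2)).shape == torch.Size([8, 207, 12, 64])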

class ODEGCN(nn.Module):
    """ the overall network framework """
    def __init__(self, args):
        """
        Args:
            args : configuration dict providing
                'num_nodes' : number of nodes in the graph
                'num_features' : number of features at each node in each time step
                'history' : number of past time steps fed into the network
                'horizon' : desired number of future time steps output by the network
        The normalized spatial adjacency matrix (A_sp_hat) and semantic adjacency
        matrix (A_se_hat) are built from args by get_A_hat.
        """
        super(ODEGCN, self).__init__()
        num_nodes = args['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(args)
# spatial graph
self.sp_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_sp_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
])
# semantic graph
self.se_blocks = nn.ModuleList([nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_se_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
])
self.pred = nn.Sequential(
nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
nn.ReLU(),
nn.Linear(num_timesteps_output * 32, num_timesteps_output)
)
    def forward(self, x):
        """
        Args:
            x : input data of shape (batch_size, num_timesteps, num_nodes, num_features) == (B, T, N, F)
        Returns:
            prediction for the future of shape (batch_size, num_timesteps_output, num_nodes, 1)
        """
        # keep only the first feature and permute to (B, N, T, 1)
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        # spatial graph
        for blk in self.sp_blocks:
            outs.append(blk(x))
        # semantic graph
        for blk in self.se_blocks:
            outs.append(blk(x))
        # element-wise max over all spatial and semantic branches
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]
        # flatten the time and channel dimensions before the prediction head
        x = x.reshape((x.shape[0], x.shape[1], -1))
        # (B, N, num_timesteps_output) -> (B, num_timesteps_output, N, 1)
        return self.pred(x).permute(0, 2, 1).unsqueeze(dim=-1)
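
if __name__ == "__main__":
    # Minimal smoke test added for illustration only. All sizes below are arbitrary
    # assumptions, and ODEGCN itself is not exercised because get_A_hat needs the
    # dataset-specific adjacency files.
    B, N, T, F_in = 4, 10, 12, 2

    tcn = TemporalConvNet(num_inputs=F_in, num_channels=[64, 32, 64])
    print('TemporalConvNet:', tuple(tcn(torch.randn(B, N, T, F_in)).shape))  # expect (4, 10, 12, 64)

    gcn = GCN(torch.eye(N), in_channels=F_in, out_channels=16)
    print('GCN:', tuple(gcn(torch.randn(B, N, T, F_in)).shape))              # expect (4, 10, 12, 16)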