TrafficWheel/model/STGCN/layers.py

285 lines
11 KiB
Python
Executable File

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
class Align(nn.Module):
    """Bring the channel dimension of a 4-D tensor from ``c_in`` to ``c_out``.

    * ``c_in > c_out``: channels are reduced with a learnable 1x1 convolution.
    * ``c_in < c_out``: the input is zero-padded along the channel axis.
    * ``c_in == c_out``: the input is returned untouched.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        # Always created, but only used when channels have to be reduced.
        self.align_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(1, 1))

    def forward(self, x):
        """Return ``x`` with ``c_out`` channels.

        ``x`` is unpacked as (batch, channel, timestep, n_vertex) when padding.
        """
        if self.c_in > self.c_out:
            return self.align_conv(x)
        if self.c_in == self.c_out:
            return x
        # c_in < c_out: append all-zero channels that share x's dtype and device.
        batch_size, _, timestep, n_vertex = x.shape
        filler = x.new_zeros([batch_size, self.c_out - self.c_in, timestep, n_vertex])
        return torch.cat([x, filler], dim=1)
class CausalConv1d(nn.Conv1d):
    """1-D convolution with optional causal padding.

    With ``enable_padding`` the underlying ``nn.Conv1d`` pads
    ``(kernel_size - 1) * dilation`` steps, and ``forward`` slices the same
    number of trailing steps off the result.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 enable_padding=False, dilation=1, groups=1, bias=True):
        # Equality (not truthiness) is kept so that only True / 1 enables padding.
        pad = (kernel_size - 1) * dilation if enable_padding == True else 0  # noqa: E712
        super().__init__(in_channels, out_channels, kernel_size=kernel_size,
                         stride=stride, padding=pad, dilation=dilation,
                         groups=groups, bias=bias)
        self.__padding = pad

    def forward(self, input):
        """Convolve ``input`` and drop the trailing padded steps, if any."""
        out = super().forward(input)
        trim = self.__padding
        return out[:, :, :-trim] if trim != 0 else out
class CausalConv2d(nn.Conv2d):
    """2-D convolution with optional one-sided (leading) padding.

    The parent ``nn.Conv2d`` is always built with ``padding=0``. When
    ``enable_padding`` is set, ``forward`` pads the input with zeros on the
    leading side of the last two dimensions by
    ``(kernel_size - 1) * dilation`` before convolving.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 enable_padding=False, dilation=1, groups=1, bias=True):
        to_pair = nn.modules.utils._pair
        kernel_size = to_pair(kernel_size)
        stride = to_pair(stride)
        dilation = to_pair(dilation)
        # Equality (not truthiness) is kept so that only True / 1 enables padding.
        if enable_padding == True:  # noqa: E712
            self.__padding = [int((k - 1) * d) for k, d in zip(kernel_size, dilation)]
        else:
            self.__padding = 0
        self.left_padding = to_pair(self.__padding)
        super().__init__(in_channels, out_channels, kernel_size, stride=stride,
                         padding=0, dilation=dilation, groups=groups, bias=bias)

    def forward(self, input):
        """Pad the leading side of the last two dims (if enabled), then convolve."""
        if self.__padding != 0:
            pad_rows, pad_cols = self.left_padding[0], self.left_padding[1]
            # F.pad order is (last-dim left, last-dim right, prev-dim left, prev-dim right).
            input = F.pad(input, (pad_cols, 0, pad_rows, 0))
        return super().forward(input)
class TemporalConvLayer(nn.Module):
    """Temporal convolution layer with a residual connection.

    Gated variant (GLU)::

               |--------------------------------| * residual connection *
               |                                |
               |    |--->--- causalconv2d ----- + -------|
        -------|----|                                   ⊙ ------>
                    |--->--- causalconv2d --- sigmoid ---|

    param x: tensor, [bs, c_in, ts, n_vertex]

    ``act_func`` selects the non-linearity: ``'glu'``, ``'gtu'``, ``'relu'``
    or ``'silu'``. Any other value makes ``forward`` raise
    ``NotImplementedError``. The convolution uses a ``(Kt, 1)`` kernel without
    padding, so the time axis shrinks by ``Kt - 1``.
    """

    def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
        super().__init__()
        self.Kt = Kt
        self.c_in = c_in
        self.c_out = c_out
        self.n_vertex = n_vertex
        self.align = Align(c_in, c_out)
        # Gated activations need twice the channels: one half is the value,
        # the other half is the gate.
        gated = act_func in ('glu', 'gtu')
        conv_channels = 2 * c_out if gated else c_out
        self.causal_conv = CausalConv2d(in_channels=c_in, out_channels=conv_channels,
                                        kernel_size=(Kt, 1), enable_padding=False, dilation=1)
        self.relu = nn.ReLU()
        self.silu = nn.SiLU()
        self.act_func = act_func

    def forward(self, x):
        """Apply the temporal convolution, residual sum and activation to ``x``."""
        # Drop the first Kt - 1 steps so the residual matches the conv output length.
        residual = self.align(x)[:, :, self.Kt - 1:, :]
        conv_out = self.causal_conv(x)
        mode = self.act_func

        if mode in ('glu', 'gtu'):
            value = conv_out[:, :self.c_out, :, :] + residual
            gate = torch.sigmoid(conv_out[:, -self.c_out:, :, :])
            if mode == 'gtu':
                # GTU: tanh(x_p + x_in) ⊙ sigmoid(x_q)
                value = torch.tanh(value)
            # GLU ("Language Modeling with Gated Convolutional Networks",
            # https://arxiv.org/abs/1612.08083): (x_p + x_in) ⊙ sigmoid(x_q)
            return torch.mul(value, gate)
        if mode == 'relu':
            return self.relu(conv_out + residual)
        if mode == 'silu':
            return self.silu(conv_out + residual)
        raise NotImplementedError(f'ERROR: The activation function {self.act_func} is not implemented.')
class ChebGraphConv(nn.Module):
    """Chebyshev polynomial graph convolution of order ``Ks - 1``.

    ``gso`` is the graph shift operator that multiplies the vertex axis.
    The weight has shape ``(Ks, c_in, c_out)``; the bias (optional) has
    shape ``(c_out,)``.
    """

    def __init__(self, c_in, c_out, Ks, gso, bias):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.Ks = Ks
        self.gso = gso
        self.weight = nn.Parameter(torch.FloatTensor(Ks, c_in, c_out))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(c_out))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init for the weight, fan-in-bounded uniform for the bias."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Apply the convolution.

        ``x`` is [bs, c_in, ts, n_vertex]; the result is
        [bs, ts, n_vertex, c_out] (channels last).

        Raises:
            ValueError: if ``Ks`` is smaller than 1.
        """
        # bs, c_in, ts, n_vertex -> bs, ts, n_vertex, c_in
        feats = x.permute(0, 2, 3, 1)
        if self.Ks - 1 < 0:
            raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}.')

        # Chebyshev recurrence: T_0 = x, T_1 = gso x, T_k = 2 gso T_{k-1} - T_{k-2}.
        terms = [feats]
        if self.Ks > 1:
            terms.append(torch.einsum('hi,btij->bthj', self.gso, feats))
        for _ in range(2, self.Ks):
            terms.append(torch.einsum('hi,btij->bthj', 2 * self.gso, terms[-1]) - terms[-2])

        stacked = torch.stack(terms, dim=2)
        out = torch.einsum('btkhi,kij->bthj', stacked, self.weight)
        if self.bias is not None:
            out = torch.add(out, self.bias)
        return out
class GraphConv(nn.Module):
    """Single-hop graph convolution: ``gso @ x @ weight (+ bias)``.

    ``gso`` is the graph shift operator that multiplies the vertex axis.
    The weight has shape ``(c_in, c_out)``; the bias (optional) has shape
    ``(c_out,)``.
    """

    def __init__(self, c_in, c_out, gso, bias):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.gso = gso
        self.weight = nn.Parameter(torch.FloatTensor(c_in, c_out))
        if bias:
            self.bias = nn.Parameter(torch.FloatTensor(c_out))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-uniform init for the weight, fan-in-bounded uniform for the bias."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is None:
            return
        fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
        limit = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
        init.uniform_(self.bias, -limit, limit)

    def forward(self, x):
        """Apply the convolution.

        ``x`` is [bs, c_in, ts, n_vertex]; the result is
        [bs, ts, n_vertex, c_out] (channels last).
        """
        # bs, c_in, ts, n_vertex -> bs, ts, n_vertex, c_in
        feats = x.permute(0, 2, 3, 1)
        # Mix along the vertex axis first, then project the channels.
        propagated = torch.einsum('hi,btij->bthj', self.gso, feats)
        projected = torch.einsum('bthi,ij->bthj', propagated, self.weight)
        if self.bias is None:
            return projected
        return torch.add(projected, self.bias)
class GraphConvLayer(nn.Module):
    """Graph convolution with channel alignment and a residual connection.

    The input is first aligned from ``c_in`` to ``c_out`` channels, then
    passed through the graph convolution selected by ``graph_conv_type``
    (``'cheb_graph_conv'`` or ``'graph_conv'``), and finally added back to
    the aligned input.
    """

    def __init__(self, graph_conv_type, c_in, c_out, Ks, gso, bias):
        super(GraphConvLayer, self).__init__()
        self.graph_conv_type = graph_conv_type
        self.c_in = c_in
        self.c_out = c_out
        self.align = Align(c_in, c_out)
        self.Ks = Ks
        self.gso = gso
        # The convolution works on already-aligned features, hence c_out -> c_out.
        if self.graph_conv_type == 'cheb_graph_conv':
            self.cheb_graph_conv = ChebGraphConv(c_out, c_out, Ks, gso, bias)
        elif self.graph_conv_type == 'graph_conv':
            self.graph_conv = GraphConv(c_out, c_out, gso, bias)

    def forward(self, x):
        """Return ``graph_conv(align(x)) + align(x)`` in channels-first layout.

        Raises:
            NotImplementedError: if ``graph_conv_type`` is not one of the
                supported values.
        """
        x_gc_in = self.align(x)
        if self.graph_conv_type == 'cheb_graph_conv':
            x_gc = self.cheb_graph_conv(x_gc_in)
        elif self.graph_conv_type == 'graph_conv':
            x_gc = self.graph_conv(x_gc_in)
        else:
            # Previously fell through and crashed on the unbound ``x_gc``.
            raise NotImplementedError(
                f'ERROR: The graph convolution type {self.graph_conv_type} is not implemented.'
            )
        # The sub-layers return channels last; move channels back to dim 1
        # so the result lines up with the aligned input.
        x_gc = x_gc.permute(0, 3, 1, 2)
        return torch.add(x_gc, x_gc_in)
class STConvBlock(nn.Module):
    """Spatio-temporal convolution block with a 'TGTND' structure.

    T: Gated Temporal Convolution Layer (GLU or GTU)
    G: Graph Convolution Layer (ChebGraphConv or GraphConv)
    T: Gated Temporal Convolution Layer (GLU or GTU)
    N: Layer Normalization
    D: Dropout

    A ReLU is applied between the graph convolution and the second temporal
    convolution.
    """

    def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, act_func,
                 graph_conv_type, gso, bias, droprate):
        super().__init__()
        first_ch, mid_ch, last_ch = channels[0], channels[1], channels[2]
        self.tmp_conv1 = TemporalConvLayer(Kt, last_block_channel, first_ch, n_vertex, act_func)
        self.graph_conv = GraphConvLayer(graph_conv_type, first_ch, mid_ch, Ks, gso, bias)
        self.tmp_conv2 = TemporalConvLayer(Kt, mid_ch, last_ch, n_vertex, act_func)
        self.tc2_ln = nn.LayerNorm([n_vertex, last_ch], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        """Run T -> G -> ReLU -> T -> LayerNorm -> Dropout on ``x``."""
        hidden = self.tmp_conv1(x)
        hidden = self.relu(self.graph_conv(hidden))
        hidden = self.tmp_conv2(hidden)
        # LayerNorm normalises the trailing (n_vertex, channel) dims, so move
        # channels last for the norm and back to dim 1 afterwards.
        channels_last = hidden.permute(0, 2, 3, 1)
        normed = self.tc2_ln(channels_last).permute(0, 3, 1, 2)
        return self.dropout(normed)
class OutputBlock(nn.Module):
    """Output block with a 'TNFF' structure.

    T: Gated Temporal Convolution Layer (GLU or GTU)
    N: Layer Normalization
    F: Fully-Connected Layer
    F: Fully-Connected Layer

    A ReLU and dropout sit between the two fully-connected layers.
    """

    def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex,
                 act_func, bias, droprate):
        super().__init__()
        hidden_ch, fc_ch = channels[0], channels[1]
        self.tmp_conv1 = TemporalConvLayer(Ko, last_block_channel, hidden_ch, n_vertex, act_func)
        self.fc1 = nn.Linear(in_features=hidden_ch, out_features=fc_ch, bias=bias)
        self.fc2 = nn.Linear(in_features=fc_ch, out_features=end_channel, bias=bias)
        self.tc1_ln = nn.LayerNorm([n_vertex, hidden_ch], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        """Run T -> LayerNorm -> FC -> ReLU -> Dropout -> FC on ``x``."""
        hidden = self.tmp_conv1(x)
        # Channels go last so LayerNorm and the linear layers act on them.
        hidden = self.tc1_ln(hidden.permute(0, 2, 3, 1))
        hidden = self.dropout(self.relu(self.fc1(hidden)))
        projected = self.fc2(hidden)
        # Restore the channels-first layout for the caller.
        return projected.permute(0, 3, 1, 2)