396 lines
12 KiB
Python
Executable File
396 lines
12 KiB
Python
Executable File
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.nn.init as init
|
|
|
|
|
|
class Align(nn.Module):
    """Bring the channel axis (dim 1) of a 4-D tensor to ``c_out`` channels.

    * ``c_in > c_out``: channels are projected down with a 1x1 convolution.
    * ``c_in < c_out``: zero-valued channels are appended after the input ones.
    * ``c_in == c_out``: the input is returned untouched.

    The 1x1 convolution is always constructed, even when ``forward`` never
    uses it.

    Args:
        c_in: number of channels the incoming tensor is expected to have.
        c_out: number of channels the returned tensor has.
    """

    def __init__(self, c_in, c_out):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.align_conv = nn.Conv2d(
            in_channels=c_in, out_channels=c_out, kernel_size=(1, 1)
        )

    def forward(self, x):
        """Return ``x`` with its channel axis resized to ``c_out``.

        Args:
            x: 4-D tensor; the zero-padding branch unpacks its shape as
                ``[batch_size, channels, timestep, n_vertex]``.

        Returns:
            Tensor with ``c_out`` channels; all other axes are unchanged.
        """
        if self.c_in > self.c_out:
            return self.align_conv(x)

        if self.c_in < self.c_out:
            batch_size, _, timestep, n_vertex = x.shape
            # new_zeros allocates directly with x's dtype and device, so no
            # temporary CPU tensor has to be created and copied over.
            zero_channels = x.new_zeros(
                (batch_size, self.c_out - self.c_in, timestep, n_vertex)
            )
            return torch.cat([x, zero_channels], dim=1)

        return x
|
|
|
|
|
|
class CausalConv1d(nn.Conv1d):
    """1-D convolution that can be made causal by left-padding its input.

    With ``enable_padding=True`` the input is padded with
    ``(kernel_size - 1) * dilation`` zeros on the left of the last axis, so
    output step ``t`` only sees inputs at steps ``<= t``. Padding is applied
    on the left only (instead of symmetric padding followed by trimming the
    output), which keeps the layer causal for any ``stride`` and avoids
    computing output steps that would be thrown away.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: integer kernel length.
        stride: convolution stride.
        enable_padding: when truthy, apply the causal left padding.
        dilation: integer kernel dilation.
        groups: forwarded to ``nn.Conv1d``.
        bias: forwarded to ``nn.Conv1d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        enable_padding=False,
        dilation=1,
        groups=1,
        bias=True,
    ):
        left_padding = (kernel_size - 1) * dilation if enable_padding else 0
        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            # The convolution itself never pads; forward() pads the left side.
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Assigned after Module.__init__ so the module is fully set up first.
        self.__padding = left_padding

    def forward(self, input):
        """Apply the (optionally causal) convolution along the last axis.

        Args:
            input: tensor accepted by ``nn.Conv1d``.

        Returns:
            The convolution output. With padding enabled and ``stride == 1``
            it has the same length as ``input``.
        """
        if self.__padding != 0:
            input = F.pad(input, (self.__padding, 0))
        return super().forward(input)
|
|
|
|
|
|
class CausalConv2d(nn.Conv2d):
    """2-D convolution with optional one-sided (top/left) zero padding.

    With ``enable_padding=True`` each of the two spatial axes is padded at its
    start with ``(kernel_size - 1) * dilation`` zeros before the convolution,
    so nothing is padded at the end of either axis. Without it the layer
    behaves like an unpadded ``nn.Conv2d``.

    Attributes:
        left_padding: 2-tuple ``(pad_dim2, pad_dim3)`` holding the padding
            applied at the start of the two trailing axes; ``(0, 0)`` when
            padding is disabled.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: int or pair.
        stride: int or pair.
        enable_padding: when truthy, apply the one-sided padding.
        dilation: int or pair.
        groups: forwarded to ``nn.Conv2d``.
        bias: forwarded to ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        enable_padding=False,
        dilation=1,
        groups=1,
        bias=True,
    ):
        kernel_size = nn.modules.utils._pair(kernel_size)
        stride = nn.modules.utils._pair(stride)
        dilation = nn.modules.utils._pair(dilation)

        if enable_padding:
            left_padding = tuple(
                int((k - 1) * d) for k, d in zip(kernel_size, dilation)
            )
        else:
            left_padding = (0, 0)

        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            # The convolution itself never pads; forward() pads one side only.
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        # Assigned after Module.__init__ so the module is fully set up first.
        self.left_padding = left_padding

    def forward(self, input):
        """Pad the start of the two trailing axes (if enabled) and convolve.

        Args:
            input: tensor accepted by ``nn.Conv2d``.

        Returns:
            The convolution output.
        """
        if any(self.left_padding):
            # F.pad lists the last axis first: (left, right, top, bottom).
            input = F.pad(
                input, (self.left_padding[1], 0, self.left_padding[0], 0)
            )
        return super().forward(input)
|
|
|
|
|
|
class TemporalConvLayer(nn.Module):
    # Temporal Convolution Layer (GLU)
    #
    #        |--------------------------------| * residual connection *
    #        |                                |
    #        |    |--->--- causalconv2d ----- + -------|
    # -------|----|                                   ⊙ ------>
    #             |--->--- causalconv2d --- sigmoid ---|
    #
    # param x: tensor, [bs, c_in, ts, n_vertex]
    #
    # The convolution uses a (Kt, 1) kernel without padding, so the output
    # keeps only ts - Kt + 1 time steps. The residual branch is the
    # channel-aligned input with its first Kt - 1 time steps dropped so both
    # branches have the same length.

    def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
        super().__init__()
        self.Kt = Kt
        self.c_in = c_in
        self.c_out = c_out
        self.n_vertex = n_vertex
        self.align = Align(c_in, c_out)

        # Gated activations need two feature maps (value and gate), hence the
        # doubled number of output channels.
        is_gated = act_func in ("glu", "gtu")
        self.causal_conv = CausalConv2d(
            in_channels=c_in,
            out_channels=2 * c_out if is_gated else c_out,
            kernel_size=(Kt, 1),
            enable_padding=False,
            dilation=1,
        )

        self.relu = nn.ReLU()
        self.silu = nn.SiLU()
        self.act_func = act_func

    def forward(self, x):
        residual = self.align(x)[:, :, self.Kt - 1 :, :]
        conv_out = self.causal_conv(x)

        if self.act_func in ("glu", "gtu"):
            # First c_out channels are the value branch, last c_out channels
            # drive the sigmoid gate.
            value = conv_out[:, : self.c_out, :, :]
            gate = torch.sigmoid(conv_out[:, -self.c_out :, :, :])

            if self.act_func == "glu":
                # GLU ("Language Modeling with Gated Convolutional Networks",
                # https://arxiv.org/abs/1612.08083): (x_p + x_in) ⊙ sigmoid(x_q)
                return torch.mul(value + residual, gate)

            # GTU: tanh(x_p + x_in) ⊙ sigmoid(x_q)
            return torch.mul(torch.tanh(value + residual), gate)

        if self.act_func == "relu":
            return self.relu(conv_out + residual)

        if self.act_func == "silu":
            return self.silu(conv_out + residual)

        raise NotImplementedError(
            f"ERROR: The activation function {self.act_func} is not implemented."
        )
|
|
|
|
|
|
class ChebGraphConv(nn.Module):
    """Graph convolution built from the Chebyshev recurrence on ``gso``.

    For an input permuted to ``[bs, ts, n_vertex, c_in]`` the layer forms

        T_0 = x
        T_1 = gso · x
        T_k = 2 · gso · T_{k-1} - T_{k-2}      (k >= 2)

    and returns ``sum_k T_k @ weight[k]`` (plus ``bias`` when present).

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        Ks: number of terms ``T_0 .. T_{Ks-1}``; must be >= 1.
        gso: 2-D tensor contracted against the vertex axis (presumably the
            graph shift operator, ``[n_vertex, n_vertex]`` — confirm with the
            caller). It is stored as a plain attribute, not as a buffer.
        bias: when truthy, a learnable bias of size ``c_out`` is added.
    """

    def __init__(self, c_in, c_out, Ks, gso, bias):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.Ks = Ks
        self.gso = gso
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor;
        # the values are filled in by reset_parameters().
        self.weight = nn.Parameter(
            torch.empty(Ks, c_in, c_out, dtype=torch.float32)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(c_out, dtype=torch.float32))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in ±1/sqrt(fan_in))."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the Chebyshev graph convolution.

        Args:
            x: tensor laid out as ``[bs, c_in, ts, n_vertex]``.

        Returns:
            Tensor laid out as ``[bs, ts, n_vertex, c_out]``.

        Raises:
            ValueError: if ``Ks`` is smaller than 1.
        """
        # bs, c_in, ts, n_vertex -> bs, ts, n_vertex, c_in
        x = torch.permute(x, (0, 2, 3, 1))

        if self.Ks < 1:
            raise ValueError(
                f"ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}."
            )

        terms = [x]
        if self.Ks >= 2:
            terms.append(torch.einsum("hi,btij->bthj", self.gso, x))
        if self.Ks >= 3:
            # Loop-invariant: scale the operator once, not on every step.
            doubled_gso = 2 * self.gso
            for _ in range(2, self.Ks):
                terms.append(
                    torch.einsum("hi,btij->bthj", doubled_gso, terms[-1])
                    - terms[-2]
                )

        # Stack the terms on a new axis k: [bs, ts, Ks, n_vertex, c_in].
        stacked = torch.stack(terms, dim=2)
        cheb_graph_conv = torch.einsum("btkhi,kij->bthj", stacked, self.weight)

        if self.bias is not None:
            cheb_graph_conv = torch.add(cheb_graph_conv, self.bias)

        return cheb_graph_conv
|
|
|
|
|
|
class GraphConv(nn.Module):
    """Single-hop graph convolution: ``(gso · x) @ weight (+ bias)``.

    Args:
        c_in: number of input channels.
        c_out: number of output channels.
        gso: 2-D tensor contracted against the vertex axis (presumably the
            graph shift operator, ``[n_vertex, n_vertex]`` — confirm with the
            caller). It is stored as a plain attribute, not as a buffer.
        bias: when truthy, a learnable bias of size ``c_out`` is added.
    """

    def __init__(self, c_in, c_out, gso, bias):
        super().__init__()
        self.c_in = c_in
        self.c_out = c_out
        self.gso = gso
        # torch.empty replaces the legacy torch.FloatTensor(...) constructor;
        # the values are filled in by reset_parameters().
        self.weight = nn.Parameter(torch.empty(c_in, c_out, dtype=torch.float32))
        if bias:
            self.bias = nn.Parameter(torch.empty(c_out, dtype=torch.float32))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight (Kaiming uniform) and bias (uniform in ±1/sqrt(fan_in))."""
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        """Apply the graph convolution.

        Args:
            x: tensor laid out as ``[bs, c_in, ts, n_vertex]``.

        Returns:
            Tensor laid out as ``[bs, ts, n_vertex, c_out]``.
        """
        # bs, c_in, ts, n_vertex -> bs, ts, n_vertex, c_in
        x = torch.permute(x, (0, 2, 3, 1))

        # Mix information across vertices, then across channels.
        propagated = torch.einsum("hi,btij->bthj", self.gso, x)
        graph_conv = torch.einsum("bthi,ij->bthj", propagated, self.weight)

        if self.bias is not None:
            graph_conv = torch.add(graph_conv, self.bias)

        return graph_conv
|
|
|
|
|
|
class GraphConvLayer(nn.Module):
    """Channel alignment followed by a graph convolution with a residual sum.

    The input is first aligned from ``c_in`` to ``c_out`` channels, then passed
    through the selected graph convolution (which maps ``c_out`` to ``c_out``
    channels), and finally added back to the aligned input.

    Args:
        graph_conv_type: ``"cheb_graph_conv"`` or ``"graph_conv"``.
        c_in: number of input channels.
        c_out: number of output channels.
        Ks: kernel size forwarded to ``ChebGraphConv`` (unused by ``GraphConv``).
        gso: operator forwarded to the graph convolution.
        bias: forwarded to the graph convolution.
    """

    def __init__(self, graph_conv_type, c_in, c_out, Ks, gso, bias):
        super().__init__()
        self.graph_conv_type = graph_conv_type
        self.c_in = c_in
        self.c_out = c_out
        self.align = Align(c_in, c_out)
        self.Ks = Ks
        self.gso = gso
        if self.graph_conv_type == "cheb_graph_conv":
            self.cheb_graph_conv = ChebGraphConv(c_out, c_out, Ks, gso, bias)
        elif self.graph_conv_type == "graph_conv":
            self.graph_conv = GraphConv(c_out, c_out, gso, bias)

    def forward(self, x):
        """Run alignment, graph convolution and the residual addition.

        Args:
            x: tensor laid out as ``[bs, c_in, ts, n_vertex]`` (the layout the
                aligned tensor is permuted back to below).

        Returns:
            Tensor with ``c_out`` channels in the same layout as the aligned
            input.

        Raises:
            NotImplementedError: if ``graph_conv_type`` is not one of the two
                supported values.
        """
        x_gc_in = self.align(x)

        if self.graph_conv_type == "cheb_graph_conv":
            x_gc = self.cheb_graph_conv(x_gc_in)
        elif self.graph_conv_type == "graph_conv":
            x_gc = self.graph_conv(x_gc_in)
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # x_gc further down.
            raise NotImplementedError(
                f"ERROR: The graph convolution type {self.graph_conv_type} is not implemented."
            )

        # The graph convolutions return channels-last; move channels back to
        # dim 1 so the result lines up with the aligned input.
        x_gc = x_gc.permute(0, 3, 1, 2)
        return torch.add(x_gc, x_gc_in)
|
|
|
|
|
|
class STConvBlock(nn.Module):
    # Spatio-temporal block following the 'TGTND' layout:
    #   T: Gated Temporal Convolution Layer (GLU or GTU)
    #   G: Graph Convolution Layer (ChebGraphConv or GraphConv)
    #   T: Gated Temporal Convolution Layer (GLU or GTU)
    #   N: Layer Normalization
    #   D: Dropout
    #
    # channels[0], channels[1] and channels[2] are the output widths of the
    # first temporal layer, the graph layer and the second temporal layer.

    def __init__(
        self,
        Kt,
        Ks,
        n_vertex,
        last_block_channel,
        channels,
        act_func,
        graph_conv_type,
        gso,
        bias,
        droprate,
    ):
        super().__init__()
        self.tmp_conv1 = TemporalConvLayer(
            Kt=Kt,
            c_in=last_block_channel,
            c_out=channels[0],
            n_vertex=n_vertex,
            act_func=act_func,
        )
        self.graph_conv = GraphConvLayer(
            graph_conv_type=graph_conv_type,
            c_in=channels[0],
            c_out=channels[1],
            Ks=Ks,
            gso=gso,
            bias=bias,
        )
        self.tmp_conv2 = TemporalConvLayer(
            Kt=Kt,
            c_in=channels[1],
            c_out=channels[2],
            n_vertex=n_vertex,
            act_func=act_func,
        )
        # Normalises jointly over the (n_vertex, channel) axes.
        self.tc2_ln = nn.LayerNorm([n_vertex, channels[2]], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        # T -> G -> ReLU -> T
        hidden = self.relu(self.graph_conv(self.tmp_conv1(x)))
        hidden = self.tmp_conv2(hidden)

        # LayerNorm expects (n_vertex, channel) as the trailing axes, so move
        # channels last, normalise, then restore the channels-first layout.
        normed = self.tc2_ln(hidden.permute(0, 2, 3, 1))
        return self.dropout(normed.permute(0, 3, 1, 2))
|
|
|
|
|
|
class OutputBlock(nn.Module):
    # Output head following the 'TNFF' layout:
    #   T: Gated Temporal Convolution Layer (GLU or GTU)
    #   N: Layer Normalization
    #   F: Fully-Connected Layer
    #   F: Fully-Connected Layer
    #
    # channels[0] is the width after the temporal layer, channels[1] the
    # hidden width between the two fully-connected layers, and end_channel
    # the width of the final output.

    def __init__(
        self,
        Ko,
        last_block_channel,
        channels,
        end_channel,
        n_vertex,
        act_func,
        bias,
        droprate,
    ):
        super().__init__()
        self.tmp_conv1 = TemporalConvLayer(
            Kt=Ko,
            c_in=last_block_channel,
            c_out=channels[0],
            n_vertex=n_vertex,
            act_func=act_func,
        )
        self.fc1 = nn.Linear(
            in_features=channels[0], out_features=channels[1], bias=bias
        )
        self.fc2 = nn.Linear(
            in_features=channels[1], out_features=end_channel, bias=bias
        )
        # Normalises jointly over the (n_vertex, channel) axes.
        self.tc1_ln = nn.LayerNorm([n_vertex, channels[0]], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        hidden = self.tmp_conv1(x)

        # Channels-last from here on: LayerNorm and the Linear layers all act
        # on the trailing axis.
        hidden = self.tc1_ln(hidden.permute(0, 2, 3, 1))
        hidden = self.dropout(self.relu(self.fc1(hidden)))

        # Project to end_channel and restore the channels-first layout.
        return self.fc2(hidden).permute(0, 3, 1, 2)
|