import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class Align(nn.Module):
    """Match the channel count of ``x`` to ``c_out`` so it can serve as a residual.

    ``forward`` unpacks ``x`` as ``[batch_size, channels, timestep, n_vertex]``.

    * ``c_in > c_out``: project with a learnable 1x1 convolution.
    * ``c_in < c_out``: append ``c_out - c_in`` all-zero channels.
    * ``c_in == c_out``: return ``x`` itself.
    """

    def __init__(self, c_in, c_out):
        super(Align, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        # Created unconditionally; only applied when c_in > c_out.
        self.align_conv = nn.Conv2d(in_channels=c_in, out_channels=c_out, kernel_size=(1, 1))

    def forward(self, x):
        if self.c_in > self.c_out:
            return self.align_conv(x)
        if self.c_in < self.c_out:
            batch_size, _, timestep, n_vertex = x.shape
            # new_zeros allocates with x's dtype and device directly, instead of
            # building the zeros on the CPU and copying them over.
            zero_channels = x.new_zeros((batch_size, self.c_out - self.c_in, timestep, n_vertex))
            return torch.cat([x, zero_channels], dim=1)
        return x


class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that, with ``enable_padding=True``, is causal and length-preserving.

    When ``enable_padding`` is set, ``nn.Conv1d`` zero-pads ``(kernel_size - 1) * dilation``
    steps on both sides and ``forward`` drops that many trailing outputs. For
    ``stride=1``, output ``t`` then depends only on inputs at or before ``t``, and
    the sequence length is unchanged. Without padding this is a plain ``nn.Conv1d``.
    ``kernel_size`` and ``dilation`` are used as integers in the padding formula.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1, bias=True):
        padding = (kernel_size - 1) * dilation if enable_padding else 0
        super(CausalConv1d, self).__init__(in_channels, out_channels, kernel_size=kernel_size, stride=stride,
                                           padding=padding, dilation=dilation, groups=groups, bias=bias)
        # Assigned after nn.Module.__init__ so the module is fully initialised first.
        self.__padding = padding

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            # Discard the outputs computed from the right-hand padding (they would look ahead).
            return result[:, :, :-self.__padding]
        return result


class CausalConv2d(nn.Conv2d):
    """``nn.Conv2d`` with optional causal (leading-side only) zero padding.

    With ``enable_padding=True`` the input is padded by ``(kernel_size[i] - 1) * dilation[i]``
    only at the start of each of the last two dims (top of dim -2, left of dim -1).
    The convolution itself uses no padding. For ``stride=1`` both sizes are
    preserved and no output looks ahead along either dim. Without it, this is an
    unpadded ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, enable_padding=False, dilation=1, groups=1, bias=True):
        kernel_size = nn.modules.utils._pair(kernel_size)
        stride = nn.modules.utils._pair(stride)
        dilation = nn.modules.utils._pair(dilation)
        if enable_padding:
            padding = [int((k - 1) * d) for k, d in zip(kernel_size, dilation)]
        else:
            padding = 0
        super(CausalConv2d, self).__init__(in_channels, out_channels, kernel_size, stride=stride, padding=0,
                                           dilation=dilation, groups=groups, bias=bias)
        self.__padding = padding
        # (padding for dim -2, padding for dim -1), applied on the leading side only.
        self.left_padding = nn.modules.utils._pair(padding)

    def forward(self, input):
        if self.__padding != 0:
            # F.pad order: (dim -1 left, dim -1 right, dim -2 top, dim -2 bottom).
            input = F.pad(input, (self.left_padding[1], 0, self.left_padding[0], 0))
        return super(CausalConv2d, self).forward(input)


class TemporalConvLayer(nn.Module):
    """Temporal convolution layer with a residual connection.

    A ``(Kt, 1)`` convolution without padding runs along the time axis, so the time
    dimension shrinks from ``ts`` to ``ts - Kt + 1``. The residual ``x_in`` is
    ``Align(c_in, c_out)(x)`` cropped to its last ``ts - Kt + 1`` time steps.

    ``act_func`` selects the output:
      * ``'glu'``: ``(P + x_in) * sigmoid(Q)`` (gated linear unit)
      * ``'gtu'``: ``tanh(P + x_in) * sigmoid(Q)`` (gated tanh unit)
      * ``'relu'`` / ``'silu'``: ``act(conv(x) + x_in)``

    ``P`` and ``Q`` are the first and last ``c_out`` channels of a ``2 * c_out``-channel
    convolution. Any other ``act_func`` raises ``NotImplementedError`` in ``forward``.

    Input ``x``: ``[bs, c_in, ts, n_vertex]``. Output: ``[bs, c_out, ts - Kt + 1, n_vertex]``.
    ``n_vertex`` is stored but not used by the computation.
    """

    def __init__(self, Kt, c_in, c_out, n_vertex, act_func):
        super(TemporalConvLayer, self).__init__()
        self.Kt = Kt
        self.c_in = c_in
        self.c_out = c_out
        self.n_vertex = n_vertex
        self.align = Align(c_in, c_out)
        if act_func == 'glu' or act_func == 'gtu':
            # Twice the channels: one half is the value path P, the other the gate Q.
            self.causal_conv = CausalConv2d(in_channels=c_in, out_channels=2 * c_out, kernel_size=(Kt, 1), enable_padding=False, dilation=1)
        else:
            self.causal_conv = CausalConv2d(in_channels=c_in, out_channels=c_out, kernel_size=(Kt, 1), enable_padding=False, dilation=1)
        self.relu = nn.ReLU()
        self.silu = nn.SiLU()
        self.act_func = act_func

    def forward(self, x):
        # Crop the residual so it lines up with the unpadded convolution output.
        x_in = self.align(x)[:, :, self.Kt - 1:, :]
        x_causal_conv = self.causal_conv(x)

        if self.act_func == 'glu' or self.act_func == 'gtu':
            x_p = x_causal_conv[:, : self.c_out, :, :]
            x_q = x_causal_conv[:, -self.c_out:, :, :]

            if self.act_func == 'glu':
                # GLU, from "Language Modeling with Gated Convolutional Networks"
                # (https://arxiv.org/abs/1612.08083): value path times sigmoid gate,
                # with the residual added to the value path.
                x = torch.mul((x_p + x_in), torch.sigmoid(x_q))
            else:
                # GTU: tanh(x_p + x_in) * sigmoid(x_q)
                x = torch.mul(torch.tanh(x_p + x_in), torch.sigmoid(x_q))

        elif self.act_func == 'relu':
            x = self.relu(x_causal_conv + x_in)

        elif self.act_func == 'silu':
            x = self.silu(x_causal_conv + x_in)

        else:
            raise NotImplementedError(f'ERROR: The activation function {self.act_func} is not implemented.')

        return x


def _register_gso(module, gso):
    """Attach ``gso`` to ``module`` as the attribute ``gso``.

    Tensors are registered as non-persistent buffers. ``module.to(...)``, ``.cuda()``,
    ``.double()`` and similar then convert them together with the weights, while
    ``state_dict()`` (and therefore saved checkpoints) stays unchanged. Non-tensor
    values are stored as a plain attribute.
    """
    if isinstance(gso, torch.Tensor):
        module.register_buffer('gso', gso, persistent=False)
    else:
        module.gso = gso


class ChebGraphConv(nn.Module):
    """Chebyshev-polynomial graph convolution.

    Computes ``sum_k (T_k(gso) X) W_k (+ bias)`` for ``k = 0 .. Ks - 1``, where
    ``T_0 X = X``, ``T_1 X = gso X`` and ``T_k X = 2 gso (T_{k-1} X) - T_{k-2} X``.

    Args:
        c_in, c_out: input / output channels.
        Ks: number of Chebyshev terms; must be >= 1 (``ValueError`` otherwise).
        gso: graph shift operator, ``[n_vertex, n_vertex]`` (indexed ``'hi'`` in einsum).
            Presumably the rescaled Laplacian expected by Chebyshev filters; confirm with caller.
        bias: whether to add a learnable bias of size ``c_out``.

    Input ``[bs, c_in, ts, n_vertex]``. Output ``[bs, ts, n_vertex, c_out]``
    (channels last; callers permute back).
    """

    def __init__(self, c_in, c_out, Ks, gso, bias):
        super(ChebGraphConv, self).__init__()
        if Ks < 1:
            # Fail at construction instead of building a zero-size/negative weight.
            raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {Ks}.')
        self.c_in = c_in
        self.c_out = c_out
        self.Ks = Ks
        _register_gso(self, gso)
        self.weight = nn.Parameter(torch.empty(Ks, c_in, c_out, dtype=torch.float32))
        if bias:
            self.bias = nn.Parameter(torch.empty(c_out, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # [bs, c_in, ts, n_vertex] -> [bs, ts, n_vertex, c_in]
        x = torch.permute(x, (0, 2, 3, 1))

        # Kept as a guard in case Ks is modified after construction.
        if self.Ks < 1:
            raise ValueError(f'ERROR: the graph convolution kernel size Ks has to be a positive integer, but received {self.Ks}.')

        x_list = [x]
        if self.Ks >= 2:
            x_list.append(torch.einsum('hi,btij->bthj', self.gso, x))
        if self.Ks >= 3:
            two_gso = 2 * self.gso  # loop-invariant
            for k in range(2, self.Ks):
                x_list.append(torch.einsum('hi,btij->bthj', two_gso, x_list[k - 1]) - x_list[k - 2])

        # [bs, ts, Ks, n_vertex, c_in] contracted with weight [Ks, c_in, c_out]
        x = torch.stack(x_list, dim=2)
        cheb_graph_conv = torch.einsum('btkhi,kij->bthj', x, self.weight)

        if self.bias is not None:
            cheb_graph_conv = torch.add(cheb_graph_conv, self.bias)

        return cheb_graph_conv


class GraphConv(nn.Module):
    """First-order graph convolution ``gso X W (+ bias)``.

    ``gso`` is ``[n_vertex, n_vertex]`` and is registered like in ``ChebGraphConv``,
    so it follows ``.to()`` / dtype conversions of the module.
    Input ``[bs, c_in, ts, n_vertex]``. Output ``[bs, ts, n_vertex, c_out]`` (channels last).
    """

    def __init__(self, c_in, c_out, gso, bias):
        super(GraphConv, self).__init__()
        self.c_in = c_in
        self.c_out = c_out
        _register_gso(self, gso)
        self.weight = nn.Parameter(torch.empty(c_in, c_out, dtype=torch.float32))
        if bias:
            self.bias = nn.Parameter(torch.empty(c_out, dtype=torch.float32))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
            init.uniform_(self.bias, -bound, bound)

    def forward(self, x):
        # [bs, c_in, ts, n_vertex] -> [bs, ts, n_vertex, c_in]
        x = torch.permute(x, (0, 2, 3, 1))
        first_mul = torch.einsum('hi,btij->bthj', self.gso, x)
        second_mul = torch.einsum('bthi,ij->bthj', first_mul, self.weight)

        if self.bias is not None:
            graph_conv = torch.add(second_mul, self.bias)
        else:
            graph_conv = second_mul

        return graph_conv


class GraphConvLayer(nn.Module):
    """Graph convolution with a residual connection.

    ``x`` is first aligned to ``c_out`` channels. The chosen graph convolution
    (``'cheb_graph_conv'`` or ``'graph_conv'``, ``c_out -> c_out``) is applied, and
    the aligned input is added back.
    Input ``[bs, c_in, ts, n_vertex]``. Output ``[bs, c_out, ts, n_vertex]``.
    Any other ``graph_conv_type`` raises ``NotImplementedError`` in ``forward``.
    """

    def __init__(self, graph_conv_type, c_in, c_out, Ks, gso, bias):
        super(GraphConvLayer, self).__init__()
        self.graph_conv_type = graph_conv_type
        self.c_in = c_in
        self.c_out = c_out
        self.align = Align(c_in, c_out)
        self.Ks = Ks
        # Plain reference only; forward uses the sub-module's own gso.
        self.gso = gso
        if self.graph_conv_type == 'cheb_graph_conv':
            self.cheb_graph_conv = ChebGraphConv(c_out, c_out, Ks, gso, bias)
        elif self.graph_conv_type == 'graph_conv':
            self.graph_conv = GraphConv(c_out, c_out, gso, bias)

    def forward(self, x):
        x_gc_in = self.align(x)
        if self.graph_conv_type == 'cheb_graph_conv':
            x_gc = self.cheb_graph_conv(x_gc_in)
        elif self.graph_conv_type == 'graph_conv':
            x_gc = self.graph_conv(x_gc_in)
        else:
            # Previously fell through to an UnboundLocalError on x_gc.
            raise NotImplementedError(f'ERROR: The graph convolution type {self.graph_conv_type} is not implemented.')
        # [bs, ts, n_vertex, c_out] -> [bs, c_out, ts, n_vertex] to match the residual.
        x_gc = x_gc.permute(0, 3, 1, 2)
        x_gc_out = torch.add(x_gc, x_gc_in)

        return x_gc_out


class STConvBlock(nn.Module):
    """Spatio-temporal convolution block with 'TGTND' structure.

    T: gated temporal convolution layer (GLU or GTU)
    G: graph convolution layer (ChebGraphConv or GraphConv), followed by ReLU
    T: gated temporal convolution layer (GLU or GTU)
    N: layer normalization over ``[n_vertex, channels[2]]``
    D: dropout

    Input ``[bs, last_block_channel, ts, n_vertex]``.
    Output ``[bs, channels[2], ts - 2 * (Kt - 1), n_vertex]``.
    """

    def __init__(self, Kt, Ks, n_vertex, last_block_channel, channels, act_func, graph_conv_type, gso, bias, droprate):
        super(STConvBlock, self).__init__()
        self.tmp_conv1 = TemporalConvLayer(Kt, last_block_channel, channels[0], n_vertex, act_func)
        self.graph_conv = GraphConvLayer(graph_conv_type, channels[0], channels[1], Ks, gso, bias)
        self.tmp_conv2 = TemporalConvLayer(Kt, channels[1], channels[2], n_vertex, act_func)
        self.tc2_ln = nn.LayerNorm([n_vertex, channels[2]], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        x = self.tmp_conv1(x)
        x = self.graph_conv(x)
        x = self.relu(x)
        x = self.tmp_conv2(x)
        # LayerNorm normalizes the trailing dims, so move channels last and back.
        x = self.tc2_ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        x = self.dropout(x)

        return x


class OutputBlock(nn.Module):
    """Output block with 'TNFF' structure.

    T: gated temporal convolution layer (GLU or GTU) with kernel ``Ko``
    N: layer normalization over ``[n_vertex, channels[0]]``
    F: fully-connected layer ``channels[0] -> channels[1]``, then ReLU and dropout
    F: fully-connected layer ``channels[1] -> end_channel``

    Input ``[bs, last_block_channel, ts, n_vertex]``.
    Output ``[bs, end_channel, ts - Ko + 1, n_vertex]``.
    """

    def __init__(self, Ko, last_block_channel, channels, end_channel, n_vertex, act_func, bias, droprate):
        super(OutputBlock, self).__init__()
        self.tmp_conv1 = TemporalConvLayer(Ko, last_block_channel, channels[0], n_vertex, act_func)
        self.fc1 = nn.Linear(in_features=channels[0], out_features=channels[1], bias=bias)
        self.fc2 = nn.Linear(in_features=channels[1], out_features=end_channel, bias=bias)
        self.tc1_ln = nn.LayerNorm([n_vertex, channels[0]], eps=1e-12)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=droprate)

    def forward(self, x):
        x = self.tmp_conv1(x)
        # Channels last for LayerNorm and the Linear layers.
        x = self.tc1_ln(x.permute(0, 2, 3, 1))
        x = self.fc1(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc2(x).permute(0, 3, 1, 2)

        return x