import torch.nn as nn

from model.STGCN import layers
from data.get_adj import get_gso


def _temporal_output_length(args):
    """Return ``Ko``, the number of time steps left after the ST-Conv stack.

    Every ST-Conv block ('TGTND') holds two gated temporal convolutions; the
    sizing formula assumes each of them shortens the time axis by ``Kt - 1``.

    Args:
        args: config mapping; reads ``'n_his'``, ``'Kt'`` and ``'stblock_num'``.

    Returns:
        int: ``n_his - 2 * (Kt - 1) * stblock_num``.

    Raises:
        ValueError: if the result is negative, i.e. the temporal convolutions
            would need more history than ``n_his`` provides.
    """
    ko = args['n_his'] - (args['Kt'] - 1) * 2 * args['stblock_num']
    if ko < 0:
        raise ValueError(
            f"n_his={args['n_his']} is too short for stblock_num="
            f"{args['stblock_num']} ST-Conv blocks with Kt={args['Kt']}: "
            f"the temporal convolutions would leave Ko={ko} time steps."
        )
    return ko


def _channel_blocks(ko, stblock_num, end_channel):
    """Return the per-stage channel configuration of an STGCN.

    Layout: ``[[1], [64, 16, 64] x stblock_num, head, [end_channel]]`` where
    ``head`` is ``[128]`` for the two-layer FC head (``ko == 0``) and
    ``[128, 128]`` for ``layers.OutputBlock`` (``ko > 0``).
    """
    blocks = [[1]]
    # Fresh list per stage so no two stages alias the same list object.
    blocks.extend([64, 16, 64] for _ in range(stblock_num))
    blocks.append([128] if ko == 0 else [128, 128])
    blocks.append([end_channel])
    return blocks


def _st_conv_stack(args, blocks, gso):
    """Chain one ``layers.STConvBlock`` per spatio-temporal stage.

    ``blocks[0]`` is the input channel spec and the last two entries belong to
    the output head, so ``len(blocks) - 3`` ST-Conv blocks are built; block
    ``i`` reads ``blocks[i][-1]`` input channels and uses ``blocks[i + 1]``.
    """
    return nn.Sequential(*[
        layers.STConvBlock(args['Kt'], args['Ks'], args['num_nodes'],
                           blocks[i][-1], blocks[i + 1], args['act_func'],
                           args['graph_conv_type'], gso,
                           args['enable_bias'], args['droprate'])
        for i in range(len(blocks) - 3)
    ])


class STGCNChebGraphConv(nn.Module):
    """STGCN with ChebyNet graph convolutions ('TGTND TGTND TNFF' structure).

    ChebGraphConv is the graph convolution from ChebyNet, using the Chebyshev
    polynomials of the first kind as a graph filter.

    Structure legend:
        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (ChebGraphConv)
        N: Layer Normalization
        D: Dropout
        F: Fully-Connected Layer

    Args:
        args: config mapping. Reads ``'n_his'``, ``'Kt'``, ``'Ks'``,
            ``'stblock_num'``, ``'num_nodes'``, ``'act_func'``,
            ``'graph_conv_type'``, ``'enable_bias'``, ``'droprate'`` plus
            whatever ``get_gso`` reads.

    Raises:
        ValueError: if ``n_his`` is too short for the temporal convolutions.
    """

    def __init__(self, args):
        super().__init__()
        # Validate the temporal budget before the (possibly costly) GSO build.
        ko = _temporal_output_length(args)
        gso = get_gso(args)
        blocks = _channel_blocks(ko, args['stblock_num'], end_channel=12)

        self.st_blocks = _st_conv_stack(args, blocks, gso)
        self.Ko = ko
        if ko > 0:
            # blocks[-2] is [128, 128] for every ko > 0, which is what
            # OutputBlock expects; ko == 1 must land here too, otherwise no
            # head is built and forward() returns raw ST-block features.
            self.output = layers.OutputBlock(
                ko, blocks[-3][-1], blocks[-2], blocks[-1][0],
                args['num_nodes'], args['act_func'], args['enable_bias'],
                args['droprate'])
        else:
            self.fc1 = nn.Linear(in_features=blocks[-3][-1],
                                 out_features=blocks[-2][0],
                                 bias=args['enable_bias'])
            self.fc2 = nn.Linear(in_features=blocks[-2][0],
                                 out_features=blocks[-1][0],
                                 bias=args['enable_bias'])
            self.relu = nn.ReLU()
            # NOTE: constructed but not applied in forward().
            self.dropout = nn.Dropout(p=args['droprate'])

    def forward(self, x):
        """Run the network on ``x``.

        Only feature 0 of the last axis is used; it is moved to axis 1 before
        the ST blocks, and axes 2 and 3 of the result are swapped on return.
        NOTE(review): ``x`` is presumably (batch, time, nodes, features) —
        confirm against the data loader.
        """
        x = x[..., 0:1].permute(0, 3, 1, 2)
        x = self.st_blocks(x)
        if self.Ko > 0:
            x = self.output(x)
        else:
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x.permute(0, 1, 3, 2)


class STGCNGraphConv(nn.Module):
    """STGCN with GCN graph convolutions ('TGTND TGTND TNFF' structure).

    GraphConv is the graph convolution from GCN. It is not the first-order
    ChebConv, because the renormalization trick is adopted. Be careful about
    over-smoothing.

    Structure legend:
        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (GraphConv)
        N: Layer Normalization
        D: Dropout
        F: Fully-Connected Layer

    Unlike ``STGCNChebGraphConv``, this model passes ``x`` to the ST blocks
    unchanged and its head ends in a single output channel.

    Args:
        args: config mapping. Reads ``'n_his'``, ``'Kt'``, ``'Ks'``,
            ``'stblock_num'``, ``'num_nodes'``, ``'act_func'``,
            ``'graph_conv_type'``, ``'enable_bias'``, ``'droprate'`` plus
            whatever ``get_gso`` reads.

    Raises:
        ValueError: if ``n_his`` is too short for the temporal convolutions.
    """

    def __init__(self, args):
        super().__init__()
        # Validate the temporal budget before the (possibly costly) GSO build.
        ko = _temporal_output_length(args)
        gso = get_gso(args)
        blocks = _channel_blocks(ko, args['stblock_num'], end_channel=1)

        self.st_blocks = _st_conv_stack(args, blocks, gso)
        self.Ko = ko
        if ko > 0:
            # blocks[-2] is [128, 128] for every ko > 0, which is what
            # OutputBlock expects; ko == 1 must land here too, otherwise no
            # head is built and forward() returns raw ST-block features.
            self.output = layers.OutputBlock(
                ko, blocks[-3][-1], blocks[-2], blocks[-1][0],
                args['num_nodes'], args['act_func'], args['enable_bias'],
                args['droprate'])
        else:
            self.fc1 = nn.Linear(in_features=blocks[-3][-1],
                                 out_features=blocks[-2][0],
                                 bias=args['enable_bias'])
            self.fc2 = nn.Linear(in_features=blocks[-2][0],
                                 out_features=blocks[-1][0],
                                 bias=args['enable_bias'])
            self.relu = nn.ReLU()
            # NOTE: constructed but not applied in forward().
            self.do = nn.Dropout(p=args['droprate'])

    def forward(self, x):
        """Run the ST blocks on ``x`` as given, then the output head."""
        x = self.st_blocks(x)
        if self.Ko > 0:
            x = self.output(x)
        else:
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x