import torch.nn as nn

from model.STGCN import layers
from data.get_adj import get_gso


class _STGCNBase(nn.Module):
    """Shared constructor for the STGCN variants in this module.

    Builds ``args["stblock_num"]`` ``layers.STConvBlock`` modules (in
    ``self.st_blocks``) followed by an output head, and stores the temporal
    length left for the head in ``self.Ko``.

    Args:
        args: config mapping. Keys read here: ``"n_his"``, ``"Kt"``, ``"Ks"``,
            ``"stblock_num"``, ``"num_nodes"``, ``"act_func"``,
            ``"graph_conv_type"``, ``"enable_bias"``, ``"droprate"``; the
            whole mapping is also passed to ``get_gso``.
        end_channel: number of output channels of the final layer.

    Raises:
        ValueError: if ``n_his`` is shorter than
            ``(Kt - 1) * 2 * stblock_num`` (i.e. ``Ko`` would be negative).
            Previously such configs silently built a shallower network with a
            mis-specified head.
    """

    def __init__(self, args, end_channel):
        super().__init__()

        # Each ST block holds two temporal convolutions; this formula assumes
        # each one trims Kt - 1 time steps -- TODO confirm against layers.
        receptive_loss = (args["Kt"] - 1) * 2 * args["stblock_num"]
        Ko = args["n_his"] - receptive_loss
        if Ko < 0:
            raise ValueError(
                f"n_his={args['n_his']} is too short for "
                f"stblock_num={args['stblock_num']} ST blocks with "
                f"Kt={args['Kt']}: need n_his >= {receptive_loss}"
            )

        gso = get_gso(args)

        # Channel specification: input channels, one [64, 16, 64] entry per
        # ST block, the head's hidden channels, then the output channels.
        blocks = [[1]]
        blocks.extend([64, 16, 64] for _ in range(args["stblock_num"]))
        blocks.append([128] if Ko == 0 else [128, 128])
        blocks.append([end_channel])

        self.st_blocks = nn.Sequential(
            *(
                layers.STConvBlock(
                    args["Kt"],
                    args["Ks"],
                    args["num_nodes"],
                    blocks[idx][-1],
                    blocks[idx + 1],
                    args["act_func"],
                    args["graph_conv_type"],
                    gso,
                    args["enable_bias"],
                    args["droprate"],
                )
                for idx in range(args["stblock_num"])
            )
        )

        self.Ko = Ko
        # The channel spec above reserves [128, 128] for every Ko > 0, so the
        # OutputBlock head is built for every Ko > 0 (previously Ko == 1 fell
        # through with no head at all).
        if Ko > 0:
            self.output = layers.OutputBlock(
                Ko,
                blocks[-3][-1],
                blocks[-2],
                blocks[-1][0],
                args["num_nodes"],
                args["act_func"],
                args["enable_bias"],
                args["droprate"],
            )
        else:  # Ko == 0
            self.fc1 = nn.Linear(
                in_features=blocks[-3][-1],
                out_features=blocks[-2][0],
                bias=args["enable_bias"],
            )
            self.fc2 = nn.Linear(
                in_features=blocks[-2][0],
                out_features=blocks[-1][0],
                bias=args["enable_bias"],
            )
            self.relu = nn.ReLU()

    def _head(self, x):
        """Apply the output head selected in ``__init__`` to ``x``."""
        if self.Ko > 0:
            return self.output(x)
        # Ko == 0: the linear layers act on the channel axis, which
        # permute(0, 2, 3, 1) moves last; permute(0, 3, 1, 2) moves it back.
        x = self.fc1(x.permute(0, 2, 3, 1))
        x = self.relu(x)
        return self.fc2(x).permute(0, 3, 1, 2)


class STGCNChebGraphConv(_STGCNBase):
    """STGCN with a 'TGTND TGTND TNFF' layout and 12 output channels.

    This class is intended as the ChebyNet variant, which uses Chebyshev
    polynomials of the first kind as the graph filter. The graph convolution
    actually built is chosen by ``args["graph_conv_type"]``. That value is
    forwarded to every ``layers.STConvBlock``; this class does not force it.

    Layout legend (the pattern shown is for ``stblock_num == 2``):
        T: gated temporal convolution layer (GLU or GTU)
        G: graph convolution layer
        N: layer normalization
        D: dropout
        F: fully-connected layer

    Args:
        args: config mapping; see ``_STGCNBase`` for the keys read.

    Raises:
        ValueError: if ``n_his`` is too short for the requested ST blocks.
    """

    def __init__(self, args):
        super().__init__(args, end_channel=12)
        if self.Ko == 0:
            # Kept so the attribute still exists; forward() does not apply it.
            self.dropout = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Run the model on ``x``.

        Only the first entry of the last axis is kept (``x[..., 0:1]``), and
        that axis is moved to position 1 before the ST blocks. The last two
        axes of the result are swapped on the way out. This presumably maps
        (batch, time, nodes, features) to (batch, 12, nodes, 1) when
        ``Ko > 0`` -- TODO confirm against ``layers.OutputBlock`` and the
        data loader.
        """
        x = x[..., 0:1].permute(0, 3, 1, 2)
        x = self._head(self.st_blocks(x))
        return x.permute(0, 1, 3, 2)


class STGCNGraphConv(_STGCNBase):
    """STGCN with a 'TGTND TGTND TNFF' layout and 1 output channel.

    This class is intended as the GCN variant. GCN's graph convolution is not
    the first-order ChebConv, because it uses the renormalization trick; be
    careful about over-smoothing. The graph convolution actually built is
    chosen by ``args["graph_conv_type"]``. That value is forwarded to every
    ``layers.STConvBlock``; this class does not force it.

    Layout legend (the pattern shown is for ``stblock_num == 2``):
        T: gated temporal convolution layer (GLU or GTU)
        G: graph convolution layer
        N: layer normalization
        D: dropout
        F: fully-connected layer

    Args:
        args: config mapping; see ``_STGCNBase`` for the keys read.

    Raises:
        ValueError: if ``n_his`` is too short for the requested ST blocks.
    """

    def __init__(self, args):
        super().__init__(args, end_channel=1)
        if self.Ko == 0:
            # Kept so the attribute still exists; forward() does not apply it.
            self.do = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Run the ST blocks and head on ``x`` without reshaping it first.

        ``x`` must already be in the layout ``self.st_blocks`` consumes
        (presumably (batch, channels, time, nodes) -- TODO confirm).
        """
        return self._head(self.st_blocks(x))