188 lines
6.0 KiB
Python
Executable File
188 lines
6.0 KiB
Python
Executable File
import torch.nn as nn
|
|
|
|
from model.STGCN import layers
|
|
from utils.get_adj import get_gso
|
|
|
|
|
|
class STGCNChebGraphConv(nn.Module):
    """STGCN variant whose spatial layers use Chebyshev graph convolution.

    The network follows the 'TGTND TGTND TNFF' structure (shown for two
    spatio-temporal blocks; the actual count is ``args["stblock_num"]``).
    ChebGraphConv is the graph convolution from ChebyNet, using the
    Chebyshev polynomials of the first kind as a graph filter.

    Spatio-temporal block:
        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (ChebGraphConv)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout

    Output head (when ``Ko > 1``; for ``Ko == 0`` only the two F layers):
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer
    """

    # Output channels of the final layer (blocks[-1]).
    # Presumably the prediction horizon -- TODO confirm against the trainer.
    _END_CHANNEL = 12

    @staticmethod
    def _build_blocks(n_his, kt, stblock_num, end_channel):
        """Return ``(blocks, Ko)`` describing the channel layout of the network.

        ``blocks[0]`` is the input channel count, ``blocks[1:-2]`` hold one
        ``[64, 16, 64]`` entry per spatio-temporal block, ``blocks[-2]`` holds
        the hidden sizes of the output head, and ``blocks[-1]`` holds the
        output channels. ``Ko`` is the temporal length left after the ST
        blocks; the formula assumes each block shortens the time axis by
        ``2 * (kt - 1)``.

        Raises:
            ValueError: if ``Ko`` is negative (the ST blocks would need more
                history than ``n_his``) or equals 1 (no output head is built
                for that case, so the model would return raw block features).
        """
        ko = n_his - (kt - 1) * 2 * stblock_num
        if ko < 0 or ko == 1:
            raise ValueError(
                "Invalid STGCN configuration: n_his - 2 * (Kt - 1) * "
                f"stblock_num = {ko}; it must be 0 or greater than 1."
            )
        blocks = [[1]]
        blocks.extend([64, 16, 64] for _ in range(stblock_num))
        blocks.append([128] if ko == 0 else [128, 128])
        blocks.append([end_channel])
        return blocks, ko

    def __init__(self, args):
        """Build the model from the config mapping ``args``.

        Keys read here: n_his, Kt, Ks, stblock_num, num_nodes, act_func,
        graph_conv_type, enable_bias, droprate (plus whatever ``get_gso``
        reads to build the graph shift operator).

        Raises:
            ValueError: when the remaining temporal length ``Ko`` is negative
                or 1.
        """
        super().__init__()
        gso = get_gso(args)
        blocks, ko = self._build_blocks(
            args["n_his"], args["Kt"], args["stblock_num"], self._END_CHANNEL
        )

        # One STConvBlock per [64, 16, 64] entry; each consumes the previous
        # entry's last channel count.
        self.st_blocks = nn.Sequential(
            *(
                layers.STConvBlock(
                    args["Kt"],
                    args["Ks"],
                    args["num_nodes"],
                    blocks[i][-1],
                    blocks[i + 1],
                    args["act_func"],
                    args["graph_conv_type"],
                    gso,
                    args["enable_bias"],
                    args["droprate"],
                )
                for i in range(args["stblock_num"])
            )
        )

        self.Ko = ko
        if ko > 1:
            self.output = layers.OutputBlock(
                ko,
                blocks[-3][-1],
                blocks[-2],
                blocks[-1][0],
                args["num_nodes"],
                args["act_func"],
                args["enable_bias"],
                args["droprate"],
            )
        else:
            # Ko == 0 (the only other value _build_blocks allows): a two-layer
            # fully connected head applied over the channel dimension.
            self.fc1 = nn.Linear(
                in_features=blocks[-3][-1],
                out_features=blocks[-2][0],
                bias=args["enable_bias"],
            )
            self.fc2 = nn.Linear(
                in_features=blocks[-2][0],
                out_features=blocks[-1][0],
                bias=args["enable_bias"],
            )
            self.relu = nn.ReLU()
            # NOTE(review): defined but never applied in forward(); kept so the
            # module structure is unchanged.
            self.dropout = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Run the network on a 4-D input tensor.

        Only feature 0 of the last axis is kept, and that axis is moved to
        dim 1 so it acts as the single input channel of the first ST block.
        NOTE(review): the input is presumably (batch, time, nodes, features)
        -- TODO confirm against the data loader.

        Returns the head output with its last two axes swapped (dim 1 is left
        in place).
        """
        x = x[..., 0:1].permute(0, 3, 1, 2)
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            # nn.Linear acts on the last axis, so move channels (dim 1) there
            # and back afterwards.
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x.permute(0, 1, 3, 2)
|
|
|
|
|
|
class STGCNGraphConv(nn.Module):
    """STGCN variant whose spatial layers use GCN-style graph convolution.

    The network follows the 'TGTND TGTND TNFF' structure (shown for two
    spatio-temporal blocks; the actual count is ``args["stblock_num"]``).
    GraphConv is the graph convolution from GCN. It is not the first-order
    ChebConv, because the renormalization trick is adopted. Be careful about
    over-smoothing.

    Spatio-temporal block:
        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (GraphConv)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout

    Output head (when ``Ko > 1``; for ``Ko == 0`` only the two F layers):
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer
    """

    # Output channels of the final layer (blocks[-1]).
    _END_CHANNEL = 1

    @staticmethod
    def _build_blocks(n_his, kt, stblock_num, end_channel):
        """Return ``(blocks, Ko)`` describing the channel layout of the network.

        ``blocks[0]`` is the input channel count, ``blocks[1:-2]`` hold one
        ``[64, 16, 64]`` entry per spatio-temporal block, ``blocks[-2]`` holds
        the hidden sizes of the output head, and ``blocks[-1]`` holds the
        output channels. ``Ko`` is the temporal length left after the ST
        blocks; the formula assumes each block shortens the time axis by
        ``2 * (kt - 1)``.

        Raises:
            ValueError: if ``Ko`` is negative (the ST blocks would need more
                history than ``n_his``) or equals 1 (no output head is built
                for that case, so the model would return raw block features).
        """
        ko = n_his - (kt - 1) * 2 * stblock_num
        if ko < 0 or ko == 1:
            raise ValueError(
                "Invalid STGCN configuration: n_his - 2 * (Kt - 1) * "
                f"stblock_num = {ko}; it must be 0 or greater than 1."
            )
        blocks = [[1]]
        blocks.extend([64, 16, 64] for _ in range(stblock_num))
        blocks.append([128] if ko == 0 else [128, 128])
        blocks.append([end_channel])
        return blocks, ko

    def __init__(self, args):
        """Build the model from the config mapping ``args``.

        Keys read here: n_his, Kt, Ks, stblock_num, num_nodes, act_func,
        graph_conv_type, enable_bias, droprate (plus whatever ``get_gso``
        reads to build the graph shift operator).

        Raises:
            ValueError: when the remaining temporal length ``Ko`` is negative
                or 1.
        """
        super().__init__()
        gso = get_gso(args)
        blocks, ko = self._build_blocks(
            args["n_his"], args["Kt"], args["stblock_num"], self._END_CHANNEL
        )

        # One STConvBlock per [64, 16, 64] entry; each consumes the previous
        # entry's last channel count.
        self.st_blocks = nn.Sequential(
            *(
                layers.STConvBlock(
                    args["Kt"],
                    args["Ks"],
                    args["num_nodes"],
                    blocks[i][-1],
                    blocks[i + 1],
                    args["act_func"],
                    args["graph_conv_type"],
                    gso,
                    args["enable_bias"],
                    args["droprate"],
                )
                for i in range(args["stblock_num"])
            )
        )

        self.Ko = ko
        if ko > 1:
            self.output = layers.OutputBlock(
                ko,
                blocks[-3][-1],
                blocks[-2],
                blocks[-1][0],
                args["num_nodes"],
                args["act_func"],
                args["enable_bias"],
                args["droprate"],
            )
        else:
            # Ko == 0 (the only other value _build_blocks allows): a two-layer
            # fully connected head applied over the channel dimension.
            self.fc1 = nn.Linear(
                in_features=blocks[-3][-1],
                out_features=blocks[-2][0],
                bias=args["enable_bias"],
            )
            self.fc2 = nn.Linear(
                in_features=blocks[-2][0],
                out_features=blocks[-1][0],
                bias=args["enable_bias"],
            )
            self.relu = nn.ReLU()
            # Named like STGCNChebGraphConv's attribute (was ``self.do``);
            # Dropout has no parameters, so state_dict keys are unaffected.
            # NOTE(review): defined but never applied in forward().
            self.dropout = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Run the ST blocks and the output head on ``x``.

        ``x`` is passed to the ST blocks unchanged, so dim 1 must already be
        the single input channel (assumes a 4-D tensor with channels on dim 1
        -- TODO confirm).

        NOTE(review): unlike STGCNChebGraphConv.forward, no feature slicing
        or axis permutation is applied to the input, and the output's last two
        axes are not swapped. Confirm which layout callers feed before using
        the two classes interchangeably.
        """
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            # nn.Linear acts on the last axis, so move channels (dim 1) there
            # and back afterwards.
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x
|