# TrafficWheel/model/STGCN/models.py
#
# (Repository-viewer metadata: 188 lines, 6.0 KiB, Python, executable file.)

import torch.nn as nn
from model.STGCN import layers
from data.get_adj import get_gso
class STGCNChebGraphConv(nn.Module):
    """STGCN built from ST-Conv blocks followed by an output head.

    The layout is 'TGTND TGTND TNFF' (one 'TGTND' group per ST block, then
    the 'TNFF' output head):
        T: gated temporal convolution layer (GLU or GTU)
        G: graph convolution layer (ChebGraphConv, the ChebyNet convolution
           that uses Chebyshev polynomials of the first kind as a graph filter)
        N: layer normalization
        D: dropout
        F: fully-connected layer

    NOTE(review): the graph convolution that is actually built is selected by
    ``args["graph_conv_type"]``, which is forwarded to every
    ``layers.STConvBlock``; the class name alone does not enforce ChebGraphConv.

    Args:
        args: configuration mapping. Keys read here: ``n_his``, ``Kt``, ``Ks``,
            ``stblock_num``, ``num_nodes``, ``act_func``, ``graph_conv_type``,
            ``enable_bias`` and ``droprate``. The whole mapping is also passed
            to ``get_gso`` (presumably to build the graph shift operator).

    Raises:
        ValueError: if the temporal length left after the ST blocks,
            ``Ko = n_his - 2 * (Kt - 1) * stblock_num``, is negative or equal
            to 1. Only ``Ko == 0`` (two linear layers) and ``Ko > 1``
            (``layers.OutputBlock``) have an output head.
    """

    def __init__(self, args):
        super(STGCNChebGraphConv, self).__init__()
        n_his = args["n_his"]
        Kt = args["Kt"]
        stblock_num = args["stblock_num"]

        # Temporal length remaining after the ST blocks; the formula charges
        # 2 * (Kt - 1) time steps per ST block.
        Ko = n_his - (Kt - 1) * 2 * stblock_num

        # Validate before building anything. A negative Ko would leave the
        # channel plan below one entry short, so one ST block fewer than
        # requested would be built. Ko == 1 matches neither output head.
        if Ko < 0:
            raise ValueError(
                f"n_his={n_his} is too short for stblock_num={stblock_num} "
                f"with Kt={Kt} (remaining temporal length Ko={Ko} < 0)"
            )
        if Ko == 1:
            raise ValueError(
                "remaining temporal length Ko=1 has no output head; adjust "
                "n_his, Kt or stblock_num so that Ko == 0 or Ko > 1"
            )

        gso = get_gso(args)

        # Channel plan:
        #   - the input channel count ([1]);
        #   - one [64, 16, 64] bottleneck per ST block;
        #   - the head's hidden width(s);
        #   - the number of output channels ([12]).
        blocks = [[1]]
        for _ in range(stblock_num):
            blocks.append([64, 16, 64])
        blocks.append([128] if Ko == 0 else [128, 128])
        blocks.append([12])

        # ST block i maps blocks[i][-1] input channels through blocks[i + 1].
        modules = [
            layers.STConvBlock(
                Kt,
                args["Ks"],
                args["num_nodes"],
                blocks[i][-1],
                blocks[i + 1],
                args["act_func"],
                args["graph_conv_type"],
                gso,
                args["enable_bias"],
                args["droprate"],
            )
            for i in range(stblock_num)
        ]
        self.st_blocks = nn.Sequential(*modules)

        self.Ko = Ko
        if self.Ko > 1:
            self.output = layers.OutputBlock(
                Ko,
                blocks[-3][-1],
                blocks[-2],
                blocks[-1][0],
                args["num_nodes"],
                args["act_func"],
                args["enable_bias"],
                args["droprate"],
            )
        else:  # Ko == 0 (guaranteed by the validation above)
            self.fc1 = nn.Linear(
                in_features=blocks[-3][-1],
                out_features=blocks[-2][0],
                bias=args["enable_bias"],
            )
            self.fc2 = nn.Linear(
                in_features=blocks[-2][0],
                out_features=blocks[-1][0],
                bias=args["enable_bias"],
            )
            self.relu = nn.ReLU()
            # Not used by forward(); kept so the module's attribute set is
            # unchanged for existing code.
            self.dropout = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Predict from ``x`` using only its first feature.

        The first entry of the last axis of ``x`` is kept and moved to axis 1
        (``permute(0, 3, 1, 2)``) before the ST blocks. After the head, the
        last two axes are swapped. The head emits ``blocks[-1][0] == 12``
        channels, which are presumably the prediction horizon.

        NOTE(review): an earlier inline note read "64,12,307,3", which suggests
        ``x`` is (batch, time, nodes, features); TODO confirm against the data
        loader.
        """
        x = x[..., 0:1].permute(0, 3, 1, 2)
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            # Move channels last so the linear layers act on them, then restore.
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x.permute(0, 1, 3, 2)
class STGCNGraphConv(nn.Module):
    """STGCN built from ST-Conv blocks followed by an output head.

    The layout is 'TGTND TGTND TNFF' (one 'TGTND' group per ST block, then
    the 'TNFF' output head):
        T: gated temporal convolution layer (GLU or GTU)
        G: graph convolution layer (GraphConv, the GCN convolution)
        N: layer normalization
        D: dropout
        F: fully-connected layer

    GraphConv is not the first-order ChebConv, because the renormalization
    trick is adopted. Be careful about over-smoothing.

    NOTE(review): the graph convolution that is actually built is selected by
    ``args["graph_conv_type"]``, which is forwarded to every
    ``layers.STConvBlock``; the class name alone does not enforce GraphConv.

    Args:
        args: configuration mapping. Keys read here: ``n_his``, ``Kt``, ``Ks``,
            ``stblock_num``, ``num_nodes``, ``act_func``, ``graph_conv_type``,
            ``enable_bias`` and ``droprate``. The whole mapping is also passed
            to ``get_gso`` (presumably to build the graph shift operator).

    Raises:
        ValueError: if the temporal length left after the ST blocks,
            ``Ko = n_his - 2 * (Kt - 1) * stblock_num``, is negative or equal
            to 1. Only ``Ko == 0`` (two linear layers) and ``Ko > 1``
            (``layers.OutputBlock``) have an output head.
    """

    def __init__(self, args):
        super(STGCNGraphConv, self).__init__()
        n_his = args["n_his"]
        Kt = args["Kt"]
        stblock_num = args["stblock_num"]

        # Temporal length remaining after the ST blocks; the formula charges
        # 2 * (Kt - 1) time steps per ST block.
        Ko = n_his - (Kt - 1) * 2 * stblock_num

        # Validate before building anything. A negative Ko would leave the
        # channel plan below one entry short, so one ST block fewer than
        # requested would be built. Ko == 1 matches neither output head.
        if Ko < 0:
            raise ValueError(
                f"n_his={n_his} is too short for stblock_num={stblock_num} "
                f"with Kt={Kt} (remaining temporal length Ko={Ko} < 0)"
            )
        if Ko == 1:
            raise ValueError(
                "remaining temporal length Ko=1 has no output head; adjust "
                "n_his, Kt or stblock_num so that Ko == 0 or Ko > 1"
            )

        gso = get_gso(args)

        # Channel plan:
        #   - the input channel count ([1]);
        #   - one [64, 16, 64] bottleneck per ST block;
        #   - the head's hidden width(s);
        #   - the number of output channels ([1]).
        blocks = [[1]]
        for _ in range(stblock_num):
            blocks.append([64, 16, 64])
        blocks.append([128] if Ko == 0 else [128, 128])
        blocks.append([1])

        # ST block i maps blocks[i][-1] input channels through blocks[i + 1].
        modules = [
            layers.STConvBlock(
                Kt,
                args["Ks"],
                args["num_nodes"],
                blocks[i][-1],
                blocks[i + 1],
                args["act_func"],
                args["graph_conv_type"],
                gso,
                args["enable_bias"],
                args["droprate"],
            )
            for i in range(stblock_num)
        ]
        self.st_blocks = nn.Sequential(*modules)

        self.Ko = Ko
        if self.Ko > 1:
            self.output = layers.OutputBlock(
                Ko,
                blocks[-3][-1],
                blocks[-2],
                blocks[-1][0],
                args["num_nodes"],
                args["act_func"],
                args["enable_bias"],
                args["droprate"],
            )
        else:  # Ko == 0 (guaranteed by the validation above)
            self.fc1 = nn.Linear(
                in_features=blocks[-3][-1],
                out_features=blocks[-2][0],
                bias=args["enable_bias"],
            )
            self.fc2 = nn.Linear(
                in_features=blocks[-2][0],
                out_features=blocks[-1][0],
                bias=args["enable_bias"],
            )
            self.relu = nn.ReLU()
            # Not used by forward(); kept so the module's attribute set is
            # unchanged for existing code.
            self.do = nn.Dropout(p=args["droprate"])

    def forward(self, x):
        """Apply the ST blocks and then the output head to ``x``.

        ``x`` goes to the ST blocks unchanged: this method does no feature
        slicing or axis permutation on the input. The first ST block expects
        ``blocks[0][-1] == 1`` input channel.

        NOTE(review): presumably the caller already supplies channel-first
        input, (batch, channels, time, nodes); TODO confirm.
        """
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            # Move channels last so the linear layers act on them, then restore.
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x