TrafficWheel/model/STGCN/models.py

134 lines
5.3 KiB
Python
Executable File

import torch.nn as nn
from model.STGCN import layers
from data.get_adj import get_gso
class STGCNChebGraphConv(nn.Module):
    """STGCN built from ``layers.STConvBlock`` units followed by an output head.

    Structure noted by the original author: 'TGTND TGTND TNFF', where
    ChebGraphConv is the graph convolution from ChebyNet (Chebyshev
    polynomials of the first kind used as a graph filter).

        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (ChebGraphConv)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout
        (the TGTND group is repeated once per ST block)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer

    Args:
        args: Mapping read with the keys ``n_his``, ``Kt``, ``Ks``,
            ``stblock_num``, ``num_nodes``, ``act_func``,
            ``graph_conv_type``, ``enable_bias`` and ``droprate``. It is also
            passed unchanged to ``get_gso``.

    Raises:
        ValueError: If ``n_his - 2 * (Kt - 1) * stblock_num`` is negative or
            equal to 1; no output head is defined for those values.
    """

    def __init__(self, args):
        super().__init__()
        num_st_blocks = args['stblock_num']
        # Ko: n_his minus 2 * (Kt - 1) per ST block. Presumably the temporal
        # length left after the ST blocks -- confirm against layers.STConvBlock.
        Ko = args['n_his'] - (args['Kt'] - 1) * 2 * num_st_blocks
        if Ko < 0 or Ko == 1:
            # Previously Ko == 1 built no output head (forward returned the raw
            # ST-block features) and Ko < 0 left `blocks` one entry short, so
            # fewer ST blocks than requested were built. Fail loudly instead.
            raise ValueError(
                "Unsupported temporal configuration: n_his={} with Kt={} and "
                "stblock_num={} gives Ko={}; Ko must be 0 or greater than 1."
                .format(args['n_his'], args['Kt'], num_st_blocks, Ko)
            )

        gso = get_gso(args)

        # Channel plan: input width, one triple per ST block, the hidden
        # width(s) of the output head, and finally the output width.
        blocks = [[1]]
        blocks.extend([64, 16, 64] for _ in range(num_st_blocks))
        blocks.append([128] if Ko == 0 else [128, 128])
        blocks.append([12])

        self.st_blocks = nn.Sequential(*[
            layers.STConvBlock(args['Kt'], args['Ks'], args['num_nodes'],
                               blocks[i][-1], blocks[i + 1],
                               args['act_func'], args['graph_conv_type'], gso,
                               args['enable_bias'], args['droprate'])
            for i in range(num_st_blocks)
        ])

        self.Ko = Ko
        if Ko > 1:
            self.output = layers.OutputBlock(Ko, blocks[-3][-1], blocks[-2], blocks[-1][0],
                                             args['num_nodes'], args['act_func'],
                                             args['enable_bias'], args['droprate'])
        else:
            # Ko == 0: two plain linear layers replace the output block.
            self.fc1 = nn.Linear(in_features=blocks[-3][-1], out_features=blocks[-2][0],
                                 bias=args['enable_bias'])
            self.fc2 = nn.Linear(in_features=blocks[-2][0], out_features=blocks[-1][0],
                                 bias=args['enable_bias'])
            self.relu = nn.ReLU()
            # NOTE(review): registered but not applied in forward().
            self.dropout = nn.Dropout(p=args['droprate'])

    def forward(self, x):
        """Apply the ST blocks and the output head to a 4-D tensor ``x``.

        Only index 0 of the last axis of ``x`` is used. NOTE(review): the
        permutes suggest an input laid out as (batch, time, node, feature) --
        confirm against the data loader.
        """
        x = x[..., 0:1]            # keep only the first entry of the last axis
        x = x.permute(0, 3, 1, 2)  # move that axis to position 1
        # Original author's shape note on this line: "64,12,307,3" -- TODO
        # confirm which tensor it describes.
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        # Swap the last two axes of the head output before returning.
        x = x.permute(0, 1, 3, 2)
        return x
class STGCNGraphConv(nn.Module):
    """STGCN built from ``layers.STConvBlock`` units followed by an output head.

    Structure noted by the original author: 'TGTND TGTND TNFF', where
    GraphConv is the graph convolution from GCN. GraphConv is not the
    first-order ChebConv, because the renormalization trick is adopted.
    Be careful about over-smoothing.

        T: Gated Temporal Convolution Layer (GLU or GTU)
        G: Graph Convolution Layer (GraphConv)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        D: Dropout
        (the TGTND group is repeated once per ST block)
        T: Gated Temporal Convolution Layer (GLU or GTU)
        N: Layer Normalization
        F: Fully-Connected Layer
        F: Fully-Connected Layer

    Args:
        args: Mapping read with the keys ``n_his``, ``Kt``, ``Ks``,
            ``stblock_num``, ``num_nodes``, ``act_func``,
            ``graph_conv_type``, ``enable_bias`` and ``droprate``. It is also
            passed unchanged to ``get_gso``.

    Raises:
        ValueError: If ``n_his - 2 * (Kt - 1) * stblock_num`` is negative or
            equal to 1; no output head is defined for those values.
    """

    def __init__(self, args):
        super().__init__()
        num_st_blocks = args['stblock_num']
        # Ko: n_his minus 2 * (Kt - 1) per ST block. Presumably the temporal
        # length left after the ST blocks -- confirm against layers.STConvBlock.
        Ko = args['n_his'] - (args['Kt'] - 1) * 2 * num_st_blocks
        if Ko < 0 or Ko == 1:
            # Previously Ko == 1 built no output head (forward returned the raw
            # ST-block features) and Ko < 0 left `blocks` one entry short, so
            # fewer ST blocks than requested were built. Fail loudly instead.
            raise ValueError(
                "Unsupported temporal configuration: n_his={} with Kt={} and "
                "stblock_num={} gives Ko={}; Ko must be 0 or greater than 1."
                .format(args['n_his'], args['Kt'], num_st_blocks, Ko)
            )

        gso = get_gso(args)

        # Channel plan: input width, one triple per ST block, the hidden
        # width(s) of the output head, and finally the output width.
        blocks = [[1]]
        blocks.extend([64, 16, 64] for _ in range(num_st_blocks))
        blocks.append([128] if Ko == 0 else [128, 128])
        blocks.append([1])

        self.st_blocks = nn.Sequential(*[
            layers.STConvBlock(args['Kt'], args['Ks'], args['num_nodes'],
                               blocks[i][-1], blocks[i + 1],
                               args['act_func'], args['graph_conv_type'], gso,
                               args['enable_bias'], args['droprate'])
            for i in range(num_st_blocks)
        ])

        self.Ko = Ko
        if Ko > 1:
            self.output = layers.OutputBlock(Ko, blocks[-3][-1], blocks[-2], blocks[-1][0],
                                             args['num_nodes'], args['act_func'],
                                             args['enable_bias'], args['droprate'])
        else:
            # Ko == 0: two plain linear layers replace the output block.
            self.fc1 = nn.Linear(in_features=blocks[-3][-1], out_features=blocks[-2][0],
                                 bias=args['enable_bias'])
            self.fc2 = nn.Linear(in_features=blocks[-2][0], out_features=blocks[-1][0],
                                 bias=args['enable_bias'])
            self.relu = nn.ReLU()
            # NOTE(review): registered but not applied in forward().
            self.do = nn.Dropout(p=args['droprate'])

    def forward(self, x):
        """Apply the ST blocks and the output head to ``x``.

        ``x`` is fed to the ST blocks as given; no slicing or axis reordering
        is done on the input here.
        """
        x = self.st_blocks(x)
        if self.Ko > 1:
            x = self.output(x)
        elif self.Ko == 0:
            x = self.fc1(x.permute(0, 2, 3, 1))
            x = self.relu(x)
            x = self.fc2(x).permute(0, 3, 1, 2)
        return x