TrafficWheel/model/STSGCN/STSGCN.py

417 lines
13 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# from _typeshed import Self
import torch
import torch.nn.functional as F
import torch.nn as nn
from model.STSGCN.get_adj import get_adj
class gcn_operation(nn.Module):
def __init__(self, adj, in_dim, out_dim, num_vertices, activation='GLU'):
"""
图卷积模块
:param adj: 邻接图
:param in_dim: 输入维度
:param out_dim: 输出维度
:param num_vertices: 节点数量
:param activation: 激活方式 {'relu', 'GLU'}
"""
super(gcn_operation, self).__init__()
self.adj = adj
self.in_dim = in_dim
self.out_dim = out_dim
self.num_vertices = num_vertices
self.activation = activation
assert self.activation in {'GLU', 'relu'}
if self.activation == 'GLU':
self.FC = nn.Linear(self.in_dim, 2 * self.out_dim, bias=True)
else:
self.FC = nn.Linear(self.in_dim, self.out_dim, bias=True)
def forward(self, x, mask=None):
"""
:param x: (3*N, B, Cin)
:param mask:(3*N, 3*N)
:return: (3*N, B, Cout)
"""
adj = self.adj
if mask is not None:
adj = adj.to(mask.device) * mask
x = torch.einsum('nm, mbc->nbc', adj.to(x.device), x) # 4*N, B, Cin
if self.activation == 'GLU':
lhs_rhs = self.FC(x) # 4*N, B, 2*Cout
lhs, rhs = torch.split(lhs_rhs, self.out_dim, dim=-1) # 4*N, B, Cout
out = lhs * torch.sigmoid(rhs)
del lhs, rhs, lhs_rhs
return out
elif self.activation == 'relu':
return torch.relu(self.FC(x)) # 3*N, B, Cout
class STSGCM(nn.Module):
def __init__(self, adj, in_dim, out_dims, num_of_vertices, activation='GLU'):
"""
:param adj: 邻接矩阵
:param in_dim: 输入维度
:param out_dims: list 各个图卷积的输出维度
:param num_of_vertices: 节点数量
:param activation: 激活方式 {'relu', 'GLU'}
"""
super(STSGCM, self).__init__()
self.adj = adj
self.in_dim = in_dim
self.out_dims = out_dims
self.num_of_vertices = num_of_vertices
self.activation = activation
self.gcn_operations = nn.ModuleList()
self.gcn_operations.append(
gcn_operation(
adj=self.adj,
in_dim=self.in_dim,
out_dim=self.out_dims[0],
num_vertices=self.num_of_vertices,
activation=self.activation
)
)
for i in range(1, len(self.out_dims)):
self.gcn_operations.append(
gcn_operation(
adj=self.adj,
in_dim=self.out_dims[i-1],
out_dim=self.out_dims[i],
num_vertices=self.num_of_vertices,
activation=self.activation
)
)
def forward(self, x, mask=None):
"""
:param x: (3N, B, Cin)
:param mask: (3N, 3N)
:return: (N, B, Cout)
"""
need_concat = []
for i in range(len(self.out_dims)):
x = self.gcn_operations[i](x, mask)
need_concat.append(x)
# shape of each element is (1, N, B, Cout)
need_concat = [
torch.unsqueeze(
h[self.num_of_vertices: 2 * self.num_of_vertices], dim=0
) for h in need_concat
]
out = torch.max(torch.cat(need_concat, dim=0), dim=0).values # (N, B, Cout)
del need_concat
return out
class STSGCL(nn.Module):
def __init__(self,
adj,
history,
num_of_vertices,
in_dim,
out_dims,
strides=4,
activation='GLU',
temporal_emb=True,
spatial_emb=True):
"""
:param adj: 邻接矩阵
:param history: 输入时间步长
:param in_dim: 输入维度
:param out_dims: list 各个图卷积的输出维度
:param strides: 滑动窗口步长local时空图使用几个时间步构建的默认为3
:param num_of_vertices: 节点数量
:param activation: 激活方式 {'relu', 'GLU'}
:param temporal_emb: 加入时间位置嵌入向量
:param spatial_emb: 加入空间位置嵌入向量
"""
super(STSGCL, self).__init__()
self.adj = adj
self.strides = strides
self.history = history
self.in_dim = in_dim
self.out_dims = out_dims
self.num_of_vertices = num_of_vertices
self.activation = activation
self.temporal_emb = temporal_emb
self.spatial_emb = spatial_emb
self.STSGCMS = nn.ModuleList()
for i in range(self.history - self.strides + 1):
self.STSGCMS.append(
STSGCM(
adj=self.adj,
in_dim=self.in_dim,
out_dims=self.out_dims,
num_of_vertices=self.num_of_vertices,
activation=self.activation
)
)
if self.temporal_emb:
self.temporal_embedding = nn.Parameter(torch.FloatTensor(1, self.history, 1, self.in_dim))
# 1, T, 1, Cin
if self.spatial_emb:
self.spatial_embedding = nn.Parameter(torch.FloatTensor(1, 1, self.num_of_vertices, self.in_dim))
# 1, 1, N, Cin
self.reset()
def reset(self):
if self.temporal_emb:
nn.init.xavier_normal_(self.temporal_embedding, gain=0.0003)
if self.spatial_emb:
nn.init.xavier_normal_(self.spatial_embedding, gain=0.0003)
def forward(self, x, mask=None):
"""
:param x: B, T, N, Cin
:param mask: (N, N)
:return: B, T-3, N, Cout
"""
if self.temporal_emb:
x = x + self.temporal_embedding
if self.spatial_emb:
x = x + self.spatial_embedding
need_concat = []
batch_size = x.shape[0]
for i in range(self.history - self.strides + 1):
t = x[:, i: i+self.strides, :, :] # (B, 4, N, Cin)
t = torch.reshape(t, shape=[batch_size, self.strides * self.num_of_vertices, self.in_dim])
# (B, 3*N, Cin)
t = self.STSGCMS[i](t.permute(1, 0, 2), mask) # (3*N, B, Cin) -> (N, B, Cout)
t = torch.unsqueeze(t.permute(1, 0, 2), dim=1) # (N, B, Cout) -> (B, N, Cout) ->(B, 1, N, Cout)
need_concat.append(t)
out = torch.cat(need_concat, dim=1) # (B, T-2, N, Cout)
del need_concat, batch_size
return out
class output_layer(nn.Module):
def __init__(self, num_of_vertices, history, in_dim, out_dim,
hidden_dim=128, horizon=12):
"""
预测层注意在作者的实验中是对每一个预测时间step做处理的也即他会令horizon=1
:param num_of_vertices:节点数
:param history:输入时间步长
:param in_dim: 输入维度
:param hidden_dim:中间层维度
:param horizon:预测时间步长
"""
super(output_layer, self).__init__()
self.num_of_vertices = num_of_vertices
self.history = history
self.in_dim = in_dim
self.out_dim = out_dim
self.hidden_dim = hidden_dim
self.horizon = horizon
self.FC1 = nn.Linear(self.in_dim * self.history, self.hidden_dim, bias=True)
self.FC2 = nn.Linear(self.hidden_dim, self.horizon * self.out_dim, bias=True)
def forward(self, x):
"""
:param x: (B, Tin, N, Cin)
:return: (B, Tout, N)
"""
batch_size = x.shape[0]
x = x.permute(0, 2, 1, 3) # B, N, Tin, Cin
out1 = torch.relu(self.FC1(x.reshape(batch_size, self.num_of_vertices, -1)))
# (B, N, Tin, Cin) -> (B, N, Tin * Cin) -> (B, N, hidden)
out2 = self.FC2(out1) # (B, N, hidden) -> (B, N, horizon * 2)
# out2 = out2.reshape(batch_size, self.num_of_vertices, self.horizon, self.out_dim)
del out1, batch_size
return out2.permute(0, 2, 1) # B, horizon, N
# return out2.permute(0, 2, 1) # B, horizon, N
class STSGCN(nn.Module):
def __init__(self, args):
"""
:param adj: local时空间矩阵
:param history:输入时间步长
:param num_of_vertices:节点数量
:param in_dim:输入维度
:param hidden_dims: lists, 中间各STSGCL层的卷积操作维度
:param first_layer_embedding_size: 第一层输入层的维度
:param out_layer_dim: 输出模块中间层维度
:param activation: 激活函数 {relu, GlU}
:param use_mask: 是否使用mask矩阵对adj进行优化
:param temporal_emb:是否使用时间嵌入向量
:param spatial_emb:是否使用空间嵌入向量
:param horizon:预测时间步长
:param strides:滑动窗口步长local时空图使用几个时间步构建的默认为4
"""
super(STSGCN, self).__init__()
self.adj = get_adj(args)
self.config = args
self.num_of_vertices = self.config.get("num_nodes", None)
self.hidden_dims = self.config.get("hidden_dims", None)
self.out_layer_dim = self.config.get("out_layer_dim", None)
self.activation = self.config.get("activation", "GLU")
self.use_mask = self.config.get("use_mask")
self.temporal_emb = self.config.get("temporal_emb", True)
self.spatial_emb = self.config.get("spatial_emb", True)
self.horizon = self.config.get("horizon", 12)
self.strides = self.config.get("strides", 3)
history = self.config.get("history", 12)
in_dim = self.config.get("input_dim", 1)
out_dim = self.config.get("output_dim", 1)
first_layer_embedding_size = self.config.get("first_layer_embedding_size", None)
self.First_FC = nn.Linear(in_dim, first_layer_embedding_size, bias=True)
self.STSGCLS = nn.ModuleList()
self.STSGCLS.append(
STSGCL(
adj=self.adj,
history=history,
num_of_vertices=self.num_of_vertices,
in_dim=first_layer_embedding_size,
out_dims=self.hidden_dims[0],
strides=self.strides,
activation=self.activation,
temporal_emb=self.temporal_emb,
spatial_emb=self.spatial_emb
)
)
in_dim = self.hidden_dims[0][-1]
history -= (self.strides - 1)
for idx, hidden_list in enumerate(self.hidden_dims):
if idx == 0:
continue
self.STSGCLS.append(
STSGCL(
adj=self.adj,
history=history,
num_of_vertices=self.num_of_vertices,
in_dim=in_dim,
out_dims=hidden_list,
strides=self.strides,
activation=self.activation,
temporal_emb=self.temporal_emb,
spatial_emb=self.spatial_emb
)
)
history -= (self.strides - 1)
in_dim = hidden_list[-1]
self.predictLayer = nn.ModuleList()
for t in range(self.horizon):
self.predictLayer.append(
output_layer(
num_of_vertices=self.num_of_vertices,
history=history,
in_dim=in_dim,
out_dim=out_dim,
hidden_dim=self.out_layer_dim,
horizon=1
)
)
if self.use_mask:
mask = torch.zeros_like(self.adj)
mask[self.adj != 0] = self.adj[self.adj != 0]
self.mask = nn.Parameter(mask)
else:
self.mask = None
def forward(self, x):
"""
:param x: B, Tin, N, Cin)
:return: B, Tout, N
"""
x = x[..., 0:1]
x = torch.relu(self.First_FC(x)) # B, Tin, N, Cin
#print(1)
for model in self.STSGCLS:
x = model(x, self.mask)
# (B, T - 8, N, Cout)
#print(2)
need_concat = []
for i in range(self.horizon):
out_step = self.predictLayer[i](x) # (B, 1, N, 2)
need_concat.append(out_step)
#print(3)
out = torch.cat(need_concat, dim=1) # B, Tout, N, 2
del need_concat
return out.unsqueeze(dim=-1)
if __name__ == '__main__':
import torch
# 定义测试用例的参数
args = {
"input_dim": 1,
"output_dim": 1,
"window": 12,
"hidden_dims": [[32, 32, 32], [32, 32, 32], [32, 32, 32]],
"first_layer_embedding_size": 64,
"out_layer_dim": 64,
"activation": "GLU",
"mask": True,
"temporal_emb": True,
"spatial_emb": True,
"horizon": 12,
"strides": 4,
"num_nodes": 883 # 根据数据集的节点数量
}
# 创建模型
model = STFGNN(args)
# 创建随机输入数据
batch_size = 64
input_data = torch.randn(batch_size, args["window"], args["num_nodes"], args["input_dim"])
# 模型推理
output = model(input_data)
# 打印输出的形状
print("Output shape:", output.shape)