TrafficWheel/model/ST_SSL/ST_SSL.py

129 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_gso
class STSSLModel(nn.Module):
    """Simplified ST-SSL spatio-temporal forecasting model.

    Encodes the first input channel with ``STEncoder`` and maps the hidden
    state of the last input time step to ``output_dim`` features with a
    linear head, emitted for each of the ``horizon`` output steps.

    Args:
        args: dict-like hyper-parameter container. ``num_nodes``,
            ``input_dim``, ``output_dim`` and ``horizon`` are always read;
            ``n_his`` is read only when ``input_length`` is absent. Missing
            optional keys are filled in *in place* with the defaults below,
            so the caller's mapping is mutated. The completed mapping is passed
            to ``get_gso`` and stored as ``self.args``. (``get_gso`` may read
            further keys. TODO confirm which.)
    """

    # Fallback hyper-parameters written into ``args`` when a key is missing.
    _DEFAULT_ARGS = {
        'd_model': 64,
        'dropout': 0.1,
        'nmb_prototype': 10,
        'batch_size': 64,
        'shm_temp': 0.1,
        'yita': 0.5,
        'percent': 0.1,
        'device': 'cpu',
        'gso_type': 'sym_norm_lap',
        'graph_conv_type': 'cheb_graph_conv',
    }

    def __init__(self, args):
        super().__init__()

        for key, value in self._DEFAULT_ARGS.items():
            if key not in args:
                args[key] = value
        # These defaults are derived from other entries, so the source key is
        # only looked up when the default is actually needed.
        if 'd_output' not in args:
            args['d_output'] = args['output_dim']
        if 'input_length' not in args:
            args['input_length'] = args['n_his']

        self.args = args
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.d_model = args['d_model']

        # Graph shift operator built from ``args`` by the project helper.
        # NOTE(review): it is handed to the encoder, but STEncoder.forward does
        # not use its ``graph`` argument.
        self.gso = get_gso(args)

        # NOTE(review): the four embedding tables below are initialised but not
        # used in ``forward``. They are kept because removing them would change
        # the module's state_dict keys and the RNG draws made at construction.
        # Presumably time-of-day (288 = 24 h at 5-min steps) and day-of-week
        # embeddings; confirm against the data pipeline.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))
        # Per-node embeddings.
        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))

        # Encoder: only a single input channel is fed in (see ``forward``).
        self.encoder = STEncoder(
            Kt=3, Ks=3,
            input_dim=1,
            hidden_dim=args['d_model'],
            input_length=args['input_length'],
            num_nodes=args['num_nodes'],
            droprate=args['dropout'],
        )
        # Prediction head: d_model -> output_dim per node.
        self.predictor = nn.Linear(args['d_model'], args['output_dim'])

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the embedding tables with Xavier-uniform values."""
        nn.init.xavier_uniform_(self.node_emb_u)
        nn.init.xavier_uniform_(self.node_emb_d)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

    def forward(self, x):
        """Forecast ``horizon`` steps from an input window.

        Args:
            x: input tensor, expected shape
                ``(batch_size, seq_len, num_nodes, features)``.

        Returns:
            Tensor of shape ``(batch_size, horizon, num_nodes, output_dim)``.
        """
        # Following DDGCRN, only the first input channel is used; slicing with
        # ``:1`` keeps the trailing feature axis (size 1).
        x = x[..., :1]
        encoded = self.encoder(x, self.gso)

        # Predict from the hidden state of the last input time step.
        last_hidden = encoded[:, -1, :, :]  # (batch_size, num_nodes, d_model)
        step_pred = self.predictor(last_hidden)  # (batch_size, num_nodes, output_dim)

        # NOTE(review): the same head is applied to the same input for every
        # output step, so all horizon steps are identical. Compute it once and
        # tile it along a new horizon axis instead of re-running the head
        # ``horizon`` times. ``repeat`` returns a contiguous tensor, like the
        # previous ``torch.stack``.
        return step_pred.unsqueeze(1).repeat(1, self.horizon, 1, 1)
class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder built from two 2-D convolutions.

    The input is treated as an image whose height is the time axis and whose
    width is the node axis. Conv -> ReLU -> dropout -> conv -> ReLU is applied,
    and the result is returned in ``(batch, seq_len, num_nodes, hidden_dim)``
    layout (when ``Kt`` and ``Ks`` are odd, the padding preserves ``seq_len``
    and ``num_nodes``).

    NOTE(review): the ``Ks``-wide kernel mixes nodes that are adjacent by
    index, not by graph adjacency, and ``forward`` ignores its ``graph``
    argument.

    Args:
        Kt: kernel size along the time axis.
        Ks: kernel size along the node axis.
        input_dim: number of input feature channels.
        hidden_dim: number of output channels (the first conv uses
            ``hidden_dim // 2``).
        input_length: stored as an attribute; not used by this class.
        num_nodes: stored as an attribute; not used by this class.
        droprate: dropout probability applied after the first conv.
    """

    def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
        super().__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        kernel = (Kt, Ks)
        padding = (Kt // 2, Ks // 2)
        self.conv1 = nn.Conv2d(input_dim, hidden_dim // 2, kernel_size=kernel, padding=padding)
        self.conv2 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=kernel, padding=padding)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x, graph):
        """Encode ``x``.

        Args:
            x: tensor of shape ``(batch_size, seq_len, num_nodes, features)``.
            graph: accepted for interface compatibility; not used.

        Returns:
            Tensor of shape ``(batch_size, seq_len', num_nodes', hidden_dim)``
            (equal to the input sizes for odd kernel sizes).

        Raises:
            ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                "STEncoder expects a 4-D input (batch, seq_len, num_nodes, features), "
                f"got shape {tuple(x.shape)}"
            )
        # Channels-first for Conv2d: (batch, features, seq_len, num_nodes).
        h = x.permute(0, 3, 1, 2)
        h = self.dropout(F.relu(self.conv1(h)))
        h = F.relu(self.conv2(h))
        # Back to (batch, seq_len, num_nodes, channels).
        return h.permute(0, 2, 3, 1)