import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_gso


class STSSLModel(nn.Module):
    def __init__(self, args):
        super(STSSLModel, self).__init__()

        # Fill in default hyper-parameters for any keys missing from args
        if 'd_model' not in args:
            args['d_model'] = 64
        if 'd_output' not in args:
            args['d_output'] = args['output_dim']
        if 'input_length' not in args:
            args['input_length'] = args['n_his']
        if 'dropout' not in args:
            args['dropout'] = 0.1
        if 'nmb_prototype' not in args:
            args['nmb_prototype'] = 10
        if 'batch_size' not in args:
            args['batch_size'] = 64
        if 'shm_temp' not in args:
            args['shm_temp'] = 0.1
        if 'yita' not in args:
            args['yita'] = 0.5
        if 'percent' not in args:
            args['percent'] = 0.1
        if 'device' not in args:
            args['device'] = 'cpu'
        if 'gso_type' not in args:
            args['gso_type'] = 'sym_norm_lap'
        if 'graph_conv_type' not in args:
            args['graph_conv_type'] = 'cheb_graph_conv'
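
        # Keys that must already be present in args (no defaults above):
        # 'num_nodes', 'input_dim', 'output_dim' and 'horizon' (plus 'n_his'
        # when 'input_length' is not given). get_gso(args) may also expect
        # dataset-specific entries (an assumption; e.g. an adjacency-matrix
        # source) that are not set here.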

        # Store commonly used arguments as attributes
        self.args = args
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.d_model = args['d_model']

        # Graph shift operator (GSO) built from the adjacency matrix
        self.gso = get_gso(args)

        # Temporal embeddings: time-of-day (288 slots) and day-of-week (7 days)
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))

        # Learnable node embeddings
        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))

        # Spatio-temporal encoder, fed with a single input channel
        self.encoder = STEncoder(
            Kt=3, Ks=3,
            input_dim=1,  # only the first feature channel is used
            hidden_dim=args['d_model'],
            input_length=args['input_length'],
            num_nodes=args['num_nodes'],
            droprate=args['dropout']
        )

        # Prediction head: project hidden features to the output dimension
        self.predictor = nn.Linear(args['d_model'], args['output_dim'])

        # Initialise the embedding parameters
        self.reset_parameters()

    def reset_parameters(self):
        # Xavier initialisation for the node and temporal embeddings
        nn.init.xavier_uniform_(self.node_emb_u)
        nn.init.xavier_uniform_(self.node_emb_d)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

    def forward(self, x):
        # x shape: (batch_size, seq_len, num_nodes, features)
        # Following the DDGCRN pattern, keep only the first feature channel
        x = x[..., 0].unsqueeze(-1)  # (batch_size, seq_len, num_nodes, 1)

        # Encode the input sequence
        encoded = self.encoder(x, self.gso)

        # Predict from the hidden state at the last time step
        last_hidden = encoded[:, -1, :, :]  # (batch_size, num_nodes, d_model)

        # Predict the next `horizon` time steps
        # (each step applies the same projection to the same hidden state)
        predictions = []
        for t in range(self.horizon):
            pred = self.predictor(last_hidden)  # (batch_size, num_nodes, output_dim)
            predictions.append(pred)

        # Stack the per-step predictions along the time dimension
        output = torch.stack(predictions, dim=1)  # (batch_size, horizon, num_nodes, output_dim)

        return output


class STEncoder(nn.Module):
    def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
        super(STEncoder, self).__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        # Simplified spatio-temporal encoder: two 2D convolutions over the
        # (time, node) grid; with odd Kt/Ks the padding keeps both sizes unchanged
        self.conv1 = nn.Conv2d(input_dim, hidden_dim // 2, kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2))
        self.conv2 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2))
        self.dropout = nn.Dropout(droprate)

    def forward(self, x, graph):
        # x: (batch_size, seq_len, num_nodes, features)
        # `graph` is accepted for interface compatibility but is not used
        # by this simplified encoder
        batch_size, seq_len, num_nodes, features = x.shape

        # Move the feature dimension to the channel position for Conv2d
        x = x.permute(0, 3, 1, 2)  # (batch_size, features, seq_len, num_nodes)

        # Convolution stack with dropout in between
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))

        # Restore the original dimension order
        x = x.permute(0, 2, 3, 1)  # (batch_size, seq_len, num_nodes, features)

        return x
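

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a quick shape check for the simplified
# STEncoder above using random data. The sizes below (batch 8, 12 time steps,
# 207 nodes, hidden_dim 64) are assumed example values, not requirements.
# Constructing STSSLModel itself is not shown because get_gso(args) needs
# dataset-specific inputs from data.get_adj.
if __name__ == '__main__':
    enc = STEncoder(Kt=3, Ks=3, input_dim=1, hidden_dim=64,
                    input_length=12, num_nodes=207, droprate=0.1)
    dummy = torch.randn(8, 12, 207, 1)  # (batch_size, seq_len, num_nodes, features)
    out = enc(dummy, graph=None)        # `graph` is ignored by this encoder
    print(out.shape)                    # expected: torch.Size([8, 12, 207, 64])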