141 lines
4.4 KiB
Python
141 lines
4.4 KiB
Python
import torch
|
||
import torch.nn as nn
|
||
import torch.nn.functional as F
|
||
from utils.get_adj import get_gso
|
||
|
||
|
||
class STSSLModel(nn.Module):
    """Spatio-temporal forecaster: convolutional encoder plus a linear prediction head.

    ``args`` is a config mapping. Missing optional keys are filled in *in place*
    by ``_apply_defaults``, so later readers of the same mapping (``get_gso(args)``
    below, and possibly the caller) see the defaults.

    Keys read directly here: ``num_nodes``, ``input_dim``, ``output_dim`` and
    ``horizon``; ``n_his`` only when ``input_length`` is absent. Whatever else
    ``get_gso`` reads is not visible from this class.
    """

    def __init__(self, args):
        super().__init__()

        self._apply_defaults(args)

        # Keep the (now default-filled) config and the frequently used sizes.
        self.args = args
        self.num_nodes = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.d_model = args["d_model"]

        # Adjacency-derived graph matrix; it is handed to the encoder in forward().
        # NOTE(review): stored as a plain attribute rather than a registered buffer,
        # so Module.to(device) will not move it if it is a tensor — confirm what
        # get_gso returns.
        self.gso = get_gso(args)

        # Learnable time embeddings. 288 presumably means 5-minute slots per day and
        # 7 days per week — TODO confirm. Not used by forward() in this class.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args["d_model"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args["d_model"]))

        # Learnable node embeddings; also not used by forward() in this class.
        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args["d_model"]))
        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args["d_model"]))

        # Encoder takes a single input channel because forward() keeps only feature 0.
        self.encoder = STEncoder(
            Kt=3,
            Ks=3,
            input_dim=1,
            hidden_dim=args["d_model"],
            input_length=args["input_length"],
            num_nodes=args["num_nodes"],
            droprate=args["dropout"],
        )

        # Prediction head: maps a d_model hidden vector to output_dim values per node.
        self.predictor = nn.Linear(args["d_model"], args["output_dim"])

        self.reset_parameters()

    @staticmethod
    def _apply_defaults(args):
        """Fill missing optional keys of ``args`` in place.

        Keys are inserted in a fixed order: d_model, d_output, input_length,
        then the constant defaults below.
        """
        if "d_model" not in args:
            args["d_model"] = 64
        # Derived defaults are looked up lazily: ``output_dim`` / ``n_his`` are only
        # read when the corresponding key is actually missing.
        if "d_output" not in args:
            args["d_output"] = args["output_dim"]
        if "input_length" not in args:
            args["input_length"] = args["n_his"]
        for key, value in (
            ("dropout", 0.1),
            ("nmb_prototype", 10),
            ("batch_size", 64),
            ("shm_temp", 0.1),
            ("yita", 0.5),
            ("percent", 0.1),
            ("device", "cpu"),
            ("gso_type", "sym_norm_lap"),
            ("graph_conv_type", "cheb_graph_conv"),
        ):
            if key not in args:
                args[key] = value

    def reset_parameters(self):
        """Xavier-uniform initialise the node and time embedding tables."""
        nn.init.xavier_uniform_(self.node_emb_u)
        nn.init.xavier_uniform_(self.node_emb_d)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

    def forward(self, x):
        """Forecast ``horizon`` future steps.

        Args:
            x: tensor laid out as (batch_size, seq_len, num_nodes, features);
                only feature 0 is used.

        Returns:
            Tensor of shape (batch_size, horizon, num_nodes, output_dim).
        """
        # Following DDGCRN, use only the first input channel.
        x = x[..., 0].unsqueeze(-1)  # (batch_size, seq_len, num_nodes, 1)

        encoded = self.encoder(x, self.gso)

        # Predict from the encoder output at the last time step.
        last_hidden = encoded[:, -1, :, :]  # (batch_size, num_nodes, d_model)

        # The same head is applied to the same hidden state for every step, so all
        # horizon steps are identical. Compute it once and tile it along a new
        # horizon axis instead of re-running the Linear layer `horizon` times.
        # repeat() also yields an empty horizon axis for horizon == 0.
        # NOTE(review): a per-step head (e.g. Linear(d_model, horizon * output_dim))
        # would let steps differ, but it changes parameter shapes and checkpoints.
        pred = self.predictor(last_hidden)  # (batch_size, num_nodes, output_dim)
        return pred.unsqueeze(1).repeat(1, self.horizon, 1, 1)
|
||
|
||
|
||
class STEncoder(nn.Module):
    """Two-layer 2-D convolutional encoder over the (time, node) grid.

    Input and output are laid out as (batch_size, seq_len, num_nodes, channels).
    Channels go ``input_dim`` -> ``hidden_dim // 2`` -> ``hidden_dim``, with ReLU
    after each convolution and dropout after the first.

    Padding is ``(Kt // 2, Ks // 2)``. For odd kernel sizes this keeps seq_len and
    num_nodes unchanged; for even sizes each layer grows that axis by one.

    NOTE(review): the Ks-wide kernel slides along the node axis, so it mixes
    nodes by index order rather than by graph structure. The ``graph`` argument
    of ``forward`` is accepted but not used.
    """

    def __init__(
        self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate
    ):
        super().__init__()
        # Kept for reference; forward() does not read them.
        self.num_nodes = num_nodes
        self.input_length = input_length

        mid_channels = hidden_dim // 2
        kernel = (Kt, Ks)
        pad = (Kt // 2, Ks // 2)
        # Creation order (conv1, conv2, dropout) fixes the seeded initialisation order.
        self.conv1 = nn.Conv2d(input_dim, mid_channels, kernel_size=kernel, padding=pad)
        self.conv2 = nn.Conv2d(mid_channels, hidden_dim, kernel_size=kernel, padding=pad)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x, graph):
        """Encode ``x`` laid out as (batch_size, seq_len, num_nodes, features).

        Returns a (batch_size, seq_len, num_nodes, hidden_dim) tensor when the
        kernel sizes are odd. ``graph`` is ignored.
        """
        # The unpacking doubles as a rank check: non-4-D input raises ValueError.
        _batch, _steps, _nodes, _channels = x.shape
        channels_first = x.permute(0, 3, 1, 2)  # (batch, features, seq_len, num_nodes)
        hidden = self.dropout(F.relu(self.conv1(channels_first)))
        hidden = F.relu(self.conv2(hidden))
        return hidden.permute(0, 2, 3, 1)  # (batch, seq_len, num_nodes, hidden_dim)
|