import torch
import torch.nn as nn
import torch.nn.functional as F
from data.get_adj import get_gso


class STSSLModel(nn.Module):
    """Spatio-temporal forecasting model: a convolutional encoder plus a linear head.

    The model reads only the first feature channel of its input, encodes it
    with :class:`STEncoder`, takes the hidden state of the last time step and
    maps it to ``output_dim`` values per node. That single prediction is
    repeated for each of the ``horizon`` output steps.

    Args:
        args: Dict-like configuration. Required keys (read unconditionally):
            ``num_nodes``, ``input_dim``, ``output_dim``, ``horizon``.
            ``n_his`` is required only when ``input_length`` is absent.
            Missing optional keys are filled in *in place* (see
            :meth:`_fill_defaults`), and the same object is kept as
            ``self.args`` and passed to ``get_gso``.

    NOTE(review): ``T_i_D_emb``, ``D_i_W_emb``, ``node_emb_u``, ``node_emb_d``
    and ``gso`` are created here but are not used by the forward pass in this
    file (the encoder ignores its ``graph`` argument). They are kept so that
    the parameter set / ``state_dict`` layout stays unchanged.
    """

    def __init__(self, args):
        super().__init__()

        # Populate optional hyper-parameters, then keep the config around.
        self._fill_defaults(args)
        self.args = args

        self.num_nodes = args["num_nodes"]
        self.input_dim = args["input_dim"]
        self.output_dim = args["output_dim"]
        self.horizon = args["horizon"]
        self.d_model = args["d_model"]

        # Graph shift operator built by the project helper from the config.
        self.gso = get_gso(args)

        # Temporal embeddings; allocated uninitialised and filled in
        # reset_parameters().  Presumably 288 = time-of-day slots and
        # 7 = days of the week -- TODO confirm against the data pipeline.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args["d_model"]))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args["d_model"]))

        # Two per-node embedding tables (re-initialised in reset_parameters()).
        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args["d_model"]))
        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args["d_model"]))

        # Encoder takes a single input channel because forward() keeps only
        # the first feature of x.
        self.encoder = STEncoder(
            Kt=3,
            Ks=3,
            input_dim=1,
            hidden_dim=args["d_model"],
            input_length=args["input_length"],
            num_nodes=args["num_nodes"],
            droprate=args["dropout"],
        )

        # Prediction head: d_model -> output_dim, applied per node.
        self.predictor = nn.Linear(args["d_model"], args["output_dim"])

        self.reset_parameters()

    @staticmethod
    def _fill_defaults(args):
        """Insert default values into ``args`` for every optional key that is missing.

        Defaults are produced lazily so that the derived ones
        (``d_output`` <- ``output_dim``, ``input_length`` <- ``n_his``) only
        touch their source key when they are actually needed.
        """
        lazy_defaults = (
            ("d_model", lambda: 64),
            ("d_output", lambda: args["output_dim"]),
            ("input_length", lambda: args["n_his"]),
            ("dropout", lambda: 0.1),
            ("nmb_prototype", lambda: 10),
            ("batch_size", lambda: 64),
            ("shm_temp", lambda: 0.1),
            ("yita", lambda: 0.5),
            ("percent", lambda: 0.1),
            ("device", lambda: "cpu"),
            ("gso_type", lambda: "sym_norm_lap"),
            ("graph_conv_type", lambda: "cheb_graph_conv"),
        )
        for key, make_default in lazy_defaults:
            if key not in args:
                args[key] = make_default()

    def reset_parameters(self):
        """Xavier-uniform initialise the node and temporal embedding tables."""
        for table in (
            self.node_emb_u,
            self.node_emb_d,
            self.T_i_D_emb,
            self.D_i_W_emb,
        ):
            nn.init.xavier_uniform_(table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Predict ``horizon`` future steps from a history window.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, num_nodes, features)``.

        Returns:
            Tensor of shape ``(batch_size, horizon, num_nodes, output_dim)``.
            Every horizon step holds the same values, since all steps are
            predicted from the same last hidden state by the same head.

        Raises:
            RuntimeError: If ``horizon`` is 0 (``torch.stack`` of an empty list).
        """
        # Following the DDGCRN convention, only the first feature channel is
        # used: (batch_size, seq_len, num_nodes, 1).
        first_channel = x[..., 0].unsqueeze(-1)

        encoded = self.encoder(first_channel, self.gso)

        # Hidden state at the final time step: (batch_size, num_nodes, d_model).
        last_hidden = encoded[:, -1, :, :]

        # The head's input does not depend on the horizon index, so evaluate it
        # once instead of once per step.  torch.stack still materialises a
        # fresh contiguous tensor: (batch_size, horizon, num_nodes, output_dim).
        step_prediction = self.predictor(last_hidden)
        return torch.stack([step_prediction] * self.horizon, dim=1)


class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder: two stacked 2-D convolutions.

    The convolutions slide over the (time, node) plane with kernel
    ``(Kt, Ks)`` and padding ``(Kt // 2, Ks // 2)``; for odd kernel sizes this
    keeps both the sequence length and the node count unchanged.

    Args:
        Kt: Kernel size along the time axis.
        Ks: Kernel size along the node axis.
        input_dim: Number of input feature channels.
        hidden_dim: Number of output channels; the first convolution
            produces ``hidden_dim // 2`` channels.
        input_length: Stored as ``self.input_length``; not otherwise used here.
        num_nodes: Stored as ``self.num_nodes``; not otherwise used here.
        droprate: Dropout probability applied between the two convolutions.
    """

    def __init__(
        self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate
    ):
        super().__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        kernel = (Kt, Ks)
        padding = (Kt // 2, Ks // 2)
        mid_channels = hidden_dim // 2

        self.conv1 = nn.Conv2d(
            input_dim, mid_channels, kernel_size=kernel, padding=padding
        )
        self.conv2 = nn.Conv2d(
            mid_channels, hidden_dim, kernel_size=kernel, padding=padding
        )
        self.dropout = nn.Dropout(droprate)

    def forward(self, x, graph):
        """Encode ``x``; the ``graph`` argument is accepted but not used.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, num_nodes, features)``.
            graph: Ignored by this simplified encoder.

        Returns:
            Tensor of shape ``(batch_size, seq_len', num_nodes', hidden_dim)``,
            where the primed sizes equal the input sizes for odd ``Kt``/``Ks``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        # The unpack doubles as a rank check: anything but a 4-D tensor fails
        # here with ValueError before reaching the convolutions.
        _batch, _steps, _nodes, _channels = x.shape

        # Channels-first layout for Conv2d:
        # (batch_size, features, seq_len, num_nodes).
        channels_first = x.permute(0, 3, 1, 2)

        hidden = self.dropout(F.relu(self.conv1(channels_first)))
        hidden = F.relu(self.conv2(hidden))

        # Back to channels-last: (batch_size, seq_len, num_nodes, hidden_dim).
        return hidden.permute(0, 2, 3, 1)