"""Spatio-temporal forecasting model: a 2-D convolutional encoder plus a linear head."""

import torch
import torch.nn as nn
import torch.nn.functional as F

from data.get_adj import get_gso

# Constant fallbacks written into ``args`` when the key is absent.
# ``d_model``, ``d_output`` and ``input_length`` are handled separately in
# ``STSSLModel.__init__`` so the original key-insertion order is kept.
_STATIC_DEFAULTS = {
    'dropout': 0.1,
    'nmb_prototype': 10,
    'batch_size': 64,
    'shm_temp': 0.1,
    'yita': 0.5,
    'percent': 0.1,
    'device': 'cpu',
    'gso_type': 'sym_norm_lap',
    'graph_conv_type': 'cheb_graph_conv',
}


class STSSLModel(nn.Module):
    """Encode the first input channel and project the last time step to a forecast.

    Args:
        args: Mutable mapping of hyper-parameters. It must contain
            ``num_nodes``, ``input_dim``, ``output_dim`` and ``horizon``, and
            ``n_his`` when ``input_length`` is not given. Missing optional
            keys are filled in **in place**, so the caller's mapping (which is
            also handed to ``get_gso``) sees the defaults afterwards.
    """

    def __init__(self, args):
        super().__init__()

        # Fill in default hyper-parameters (mutates ``args`` on purpose).
        if 'd_model' not in args:
            args['d_model'] = 64
        # These two defaults are derived from other entries and are only
        # looked up when actually needed.
        if 'd_output' not in args:
            args['d_output'] = args['output_dim']
        if 'input_length' not in args:
            args['input_length'] = args['n_his']
        for key, value in _STATIC_DEFAULTS.items():
            if key not in args:
                args[key] = value

        # Keep the configuration around.
        self.args = args
        self.num_nodes = args['num_nodes']
        self.input_dim = args['input_dim']
        self.output_dim = args['output_dim']
        self.horizon = args['horizon']
        self.d_model = args['d_model']

        # Graph shift operator built from the configuration.
        # NOTE(review): stored as a plain attribute; if it is a tensor it will
        # not follow ``model.to(device)`` -- confirm against ``get_gso``.
        self.gso = get_gso(args)

        # Temporal embeddings (filled by ``reset_parameters``).
        # NOTE(review): 288 and 7 are presumably time-of-day slots and
        # days-of-week -- confirm. Neither table is read in ``forward``.
        self.T_i_D_emb = nn.Parameter(torch.empty(288, args['d_model']))
        self.D_i_W_emb = nn.Parameter(torch.empty(7, args['d_model']))

        # Node embeddings; re-initialised by ``reset_parameters`` and likewise
        # not read in ``forward``.
        self.node_emb_u = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))
        self.node_emb_d = nn.Parameter(torch.randn(self.num_nodes, args['d_model']))

        # Encoder operating on a single input channel.
        self.encoder = STEncoder(
            Kt=3,
            Ks=3,
            input_dim=1,  # only the first channel is fed to the encoder
            hidden_dim=args['d_model'],
            input_length=args['input_length'],
            num_nodes=args['num_nodes'],
            droprate=args['dropout'],
        )

        # Prediction head.
        self.predictor = nn.Linear(args['d_model'], args['output_dim'])

        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise the node and temporal embedding tables."""
        nn.init.xavier_uniform_(self.node_emb_u)
        nn.init.xavier_uniform_(self.node_emb_d)
        nn.init.xavier_uniform_(self.T_i_D_emb)
        nn.init.xavier_uniform_(self.D_i_W_emb)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forecast ``horizon`` steps from an input window.

        Args:
            x: Tensor laid out as (batch_size, seq_len, num_nodes, features).

        Returns:
            Tensor of shape (batch_size, horizon, num_nodes, output_dim).
        """
        # Keep only the first feature channel: (batch, seq_len, num_nodes, 1).
        x = x[..., 0].unsqueeze(-1)

        encoded = self.encoder(x, self.gso)

        # Hidden state of the last time step: (batch, num_nodes, d_model).
        last_hidden = encoded[:, -1, :, :]

        # The head's input does not depend on the horizon step, so it is
        # evaluated once and the result is stacked ``horizon`` times.
        # NOTE(review): every horizon step therefore receives the same
        # prediction -- confirm this is intended.
        pred = self.predictor(last_hidden)  # (batch, num_nodes, output_dim)
        return torch.stack([pred] * self.horizon, dim=1)


class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder made of two 2-D convolutions.

    Args:
        Kt: Kernel size along the time axis.
        Ks: Kernel size along the node axis.
        input_dim: Number of input feature channels.
        hidden_dim: Number of output channels; the first convolution emits
            ``hidden_dim // 2`` channels.
        input_length: Stored on the module; not used by ``forward``.
        num_nodes: Stored on the module; not used by ``forward``.
        droprate: Dropout probability applied between the two convolutions.
    """

    def __init__(self, Kt, Ks, input_dim, hidden_dim, input_length, num_nodes, droprate):
        super().__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        kernel = (Kt, Ks)
        # Half-kernel padding keeps (seq_len, num_nodes) unchanged for odd kernels.
        padding = (Kt // 2, Ks // 2)
        self.conv1 = nn.Conv2d(input_dim, hidden_dim // 2, kernel_size=kernel, padding=padding)
        self.conv2 = nn.Conv2d(hidden_dim // 2, hidden_dim, kernel_size=kernel, padding=padding)
        self.dropout = nn.Dropout(droprate)

    def forward(self, x: torch.Tensor, graph) -> torch.Tensor:
        """Encode ``x`` with conv -> ReLU -> dropout -> conv -> ReLU.

        Args:
            x: Tensor laid out as (batch_size, seq_len, num_nodes, features).
            graph: Accepted for interface compatibility; currently unused.

        Returns:
            Tensor with channels last: (batch_size, seq_len', num_nodes', hidden_dim),
            where the two middle sizes equal the input's for odd kernel sizes.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (batch, seq_len, num_nodes, features), "
                f"got shape {tuple(x.shape)}"
            )

        # Channels first for Conv2d: (batch, features, seq_len, num_nodes).
        x = x.permute(0, 3, 1, 2)

        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))

        # Back to channels last: (batch, seq_len, num_nodes, hidden_dim).
        return x.permute(0, 2, 3, 1)