import torch
import torch.nn as nn
import torch.nn.functional as F


def masked_mae_loss(mask_value=5.0):
    """Build a masked mean-absolute-error loss function.

    Positions where ``target == mask_value`` are excluded from the loss.

    :param mask_value: target value that marks a position as ignored.
    :return: callable ``loss_fn(pred, target)`` returning a scalar tensor.
    """

    def loss_fn(pred, target):
        """Return the summed absolute error over kept positions / their count."""
        mask = (target != mask_value).float()
        mae = F.l1_loss(pred * mask, target * mask, reduction="sum")
        # The epsilon keeps the division finite when every position is masked.
        return mae / (mask.sum() + 1e-8)

    return loss_fn


def aug_topology(sim_mx, graph, percent=0.1):
    """Simplified topology augmentation: return ``graph`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility but
    are not used by this placeholder.
    """
    return graph


def aug_traffic(sim_mx, data, percent=0.1):
    """Simplified traffic augmentation: return ``data`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility but
    are not used by this placeholder.
    """
    return data


class STEncoder(nn.Module):
    """Simplified encoder: two 2-D convolutions with ReLU and dropout."""

    def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
        """Create the convolution stack and placeholder similarity matrices.

        :param Kt: kernel size along the first convolved axis.
        :param Ks: kernel size along the second convolved axis.
        :param blocks: channel specification; only ``blocks[0]`` is used, as
            ``[in_channels, hidden_channels, out_channels]``.
        :param input_length: side length of the temporal similarity matrix.
        :param num_nodes: side length of the spatial similarity matrix.
        :param droprate: dropout probability applied after the first conv.
        """
        super().__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        in_channels, hidden_channels, out_channels = blocks[0][:3]
        kernel = (Kt, Ks)
        # Half-kernel padding; preserves the convolved sizes for odd kernels.
        padding = (Kt // 2, Ks // 2)
        self.conv1 = nn.Conv2d(
            in_channels, hidden_channels, kernel_size=kernel, padding=padding
        )
        self.conv2 = nn.Conv2d(
            hidden_channels, out_channels, kernel_size=kernel, padding=padding
        )
        self.dropout = nn.Dropout(droprate)

        # Temporary (random) similarity matrices. They are plain attributes,
        # not registered buffers, so they are not part of ``state_dict`` and
        # do not follow ``.to(device)``.
        self.s_sim_mx = torch.randn(num_nodes, num_nodes)
        self.t_sim_mx = torch.randn(input_length, input_length)

    def forward(self, x, graph):
        """Encode a 4-D input whose last axis holds the feature channels.

        :param x: 4-D tensor; the last axis is moved to the channel position
            for the convolutions and moved back afterwards.
            NOTE(review): the original comment here named the axes
            (batch_size, num_nodes, seq_len, features) while the caller's
            comment says "n,l,v,c" -- confirm the intended axis order.
        :param graph: accepted for interface compatibility; unused here.
        :return: tensor with the channel axis last again; its size is the
            output channel count of ``conv2``.
        :raises ValueError: if ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input tensor, got {x.dim()} dimension(s)"
            )

        # Channels-last -> channels-first, as required by nn.Conv2d.
        x = x.permute(0, 3, 1, 2)

        # Match the channel count conv1 was built with instead of assuming 2,
        # so encoders configured with a different blocks[0][0] still work.
        expected = self.conv1.in_channels
        channels = x.shape[1]
        if channels != expected:
            if channels == 1:
                # Replicate a single channel to the expected count.
                x = x.repeat(1, expected, 1, 1)
            else:
                # Keep the leading channels. If there are too few channels,
                # the slice is a no-op and conv1 reports the mismatch.
                x = x[:, :expected, :, :]

        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))

        # Channels-first -> channels-last.
        return x.permute(0, 2, 3, 1)


class MLP(nn.Module):
    """Single linear projection applied to the last axis."""

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return ``fc(x)``."""
        return self.fc(x)


class TemporalHeteroModel(nn.Module):
    """Simplified temporal branch: MSE between linear projections.

    ``batch_size``, ``num_nodes`` and ``device`` are accepted for interface
    compatibility but are not used.
    """

    def __init__(self, d_model, batch_size, num_nodes, device):
        super().__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the mean squared error between ``fc(z1)`` and ``fc(z2)``."""
        return F.mse_loss(self.fc(z1), self.fc(z2))


class SpatialHeteroModel(nn.Module):
    """Simplified spatial branch: MSE between linear projections.

    ``nmb_prototype``, ``batch_size`` and ``shm_temp`` are accepted for
    interface compatibility but are not used.
    """

    def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
        super().__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the mean squared error between ``fc(z1)`` and ``fc(z2)``."""
        return F.mse_loss(self.fc(z1), self.fc(z2))


class STSSL(nn.Module):
    """Encoder plus prediction, temporal and spatial loss branches.

    ``args`` is a mapping read with the keys ``d_model``, ``d_output``,
    ``input_length``, ``num_nodes``, ``dropout``, ``batch_size``, ``device``,
    ``nmb_prototype``, ``shm_temp``, ``percent`` and ``yita``.
    """

    def __init__(self, args):
        super().__init__()
        d_model = args["d_model"]
        half = int(d_model // 2)

        # Spatial-temporal encoder; the input is configured with 2 channels.
        self.encoder = STEncoder(
            Kt=3,
            Ks=3,
            blocks=[
                [2, half, d_model],
                [d_model, half, d_model],
            ],
            input_length=args["input_length"],
            num_nodes=args["num_nodes"],
            droprate=args["dropout"],
        )

        # Traffic flow prediction branch.
        self.mlp = MLP(d_model, args["d_output"])
        # Temporal heterogeneity modeling branch.
        self.thm = TemporalHeteroModel(
            d_model, args["batch_size"], args["num_nodes"], args["device"]
        )
        # Spatial heterogeneity modeling branch.
        self.shm = SpatialHeteroModel(
            d_model, args["nmb_prototype"], args["batch_size"], args["shm_temp"]
        )
        self.mae = masked_mae_loss(mask_value=5.0)
        self.args = args

    def forward(self, view1, graph):
        """Encode the input and an augmented copy of it.

        :param view1: input tensor passed to the encoder (original note:
            "n,l,v,c" -- TODO confirm against the encoder's axis order).
        :param graph: graph passed to the encoder (original note: "v,v").
        :return: ``(repr1, repr2)``, the encodings of the original and the
            augmented view.
        """
        repr1 = self.encoder(view1, graph)

        # The topology augmentation is called with twice the traffic percent.
        s_sim_mx = self.fetch_spatial_sim()
        graph2 = aug_topology(s_sim_mx, graph, percent=self.args["percent"] * 2)

        t_sim_mx = self.fetch_temporal_sim()
        view2 = aug_traffic(t_sim_mx, view1, percent=self.args["percent"])

        repr2 = self.encoder(view2, graph2)
        return repr1, repr2

    def fetch_spatial_sim(self):
        """Return the encoder's spatial similarity matrix on the CPU.

        :return: tensor of shape (num_nodes, num_nodes).
        """
        return self.encoder.s_sim_mx.cpu()

    def fetch_temporal_sim(self):
        """Return the encoder's temporal similarity matrix on the CPU.

        :return: tensor of shape (input_length, input_length).
        """
        return self.encoder.t_sim_mx.cpu()

    def predict(self, z1, z2):
        """Predict future traffic flow from the first representation.

        :param z1: representation fed to the prediction MLP; its last axis
            must have size ``args["d_model"]``.
        :param z2: accepted for interface compatibility; unused.
        :return: ``mlp(z1)``, whose last axis has size ``args["d_output"]``.
        """
        return self.mlp(z1)

    def loss(self, z1, z2, y_true, scaler, loss_weights):
        """Combine prediction, temporal and spatial losses.

        :param z1, z2: representations of the two views.
        :param y_true: ground truth in the scaler's transformed space.
        :param scaler: object exposing ``inverse_transform(tensor)``.
        :param loss_weights: indexable with three weights, in the order
            prediction, temporal, spatial.
        :return: ``(total, parts)`` where ``total`` is the weighted sum as a
            tensor and ``parts`` lists the three unweighted losses as floats.
        """
        pred = self.pred_loss(z1, z2, y_true, scaler)
        temporal = self.temporal_loss(z1, z2)
        spatial = self.spatial_loss(z1, z2)

        sep_loss = [pred.item(), temporal.item(), spatial.item()]
        total = loss_weights[0] * pred
        total = total + loss_weights[1] * temporal
        total = total + loss_weights[2] * spatial
        return total, sep_loss

    def pred_loss(self, z1, z2, y_true, scaler):
        """Return the masked MAE of the prediction, blended over two channels.

        Both prediction and target are mapped through
        ``scaler.inverse_transform`` first. Channel 0 is weighted by
        ``args["yita"]`` and channel 1 by ``1 - args["yita"]``.
        """
        y_pred = scaler.inverse_transform(self.predict(z1, z2))
        y_true = scaler.inverse_transform(y_true)
        yita = self.args["yita"]
        first = self.mae(y_pred[..., 0], y_true[..., 0])
        second = self.mae(y_pred[..., 1], y_true[..., 1])
        return yita * first + (1 - yita) * second

    def temporal_loss(self, z1, z2):
        """Return the temporal branch loss for the two representations."""
        return self.thm(z1, z2)

    def spatial_loss(self, z1, z2):
        """Return the spatial branch loss for the two representations."""
        return self.shm(z1, z2)