# TrafficWheel/model/ST_SSL/models.py

import torch
import torch.nn as nn
import torch.nn.functional as F
# 简化的masked_mae_loss函数
def masked_mae_loss(mask_value=5.0):
def loss_fn(pred, target):
mask = (target != mask_value).float()
mae = F.l1_loss(pred * mask, target * mask, reduction='sum')
return mae / (mask.sum() + 1e-8)
return loss_fn
# Simplified data-augmentation stubs.
def aug_topology(sim_mx, graph, percent=0.1):
    """Graph-topology augmentation stub: returns ``graph`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility
    with the full ST-SSL implementation but are ignored here.
    """
    return graph
def aug_traffic(sim_mx, data, percent=0.1):
    """Traffic-data augmentation stub: returns ``data`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility
    with the full ST-SSL implementation but are ignored here.
    """
    return data
class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder.

    Two stacked ``Conv2d`` layers (ReLU + dropout in between) map a
    channels-last traffic tensor to a ``blocks[0][2]``-dim representation.
    Random matrices stand in for the learned similarity matrices of the
    full ST-SSL model.
    """

    def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
        super(STEncoder, self).__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length
        # Simplified encoder; channel widths taken from the first block spec.
        c_in, c_mid, c_out = blocks[0][0], blocks[0][1], blocks[0][2]
        pad = (Kt // 2, Ks // 2)  # "same" padding for odd kernel sizes
        self.conv1 = nn.Conv2d(c_in, c_mid, kernel_size=(Kt, Ks), padding=pad)
        self.conv2 = nn.Conv2d(c_mid, c_out, kernel_size=(Kt, Ks), padding=pad)
        self.dropout = nn.Dropout(droprate)
        # Placeholder similarity matrices (random; the real model learns these).
        self.s_sim_mx = torch.randn(num_nodes, num_nodes)
        self.t_sim_mx = torch.randn(input_length, input_length)

    def forward(self, x, graph):
        """Encode ``x`` of shape (batch, nodes, time, features); ``graph`` is unused here."""
        # Channels-first layout for Conv2d: (batch, features, nodes, time).
        x = x.permute(0, 3, 1, 2)
        # conv1 expects exactly 2 input channels; coerce if necessary.
        n_feat = x.shape[1]
        if n_feat == 1:
            x = x.repeat(1, 2, 1, 1)       # duplicate the single channel
        elif n_feat > 2:
            x = x[:, :2, :, :]             # keep the first two channels
        x = self.dropout(F.relu(self.conv1(x)))
        x = F.relu(self.conv2(x))
        # Back to channels-last: (batch, nodes, time, features).
        return x.permute(0, 2, 3, 1)
class MLP(nn.Module):
    """Single linear projection used as the traffic-prediction head."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project the last dimension of ``x`` from input_dim to output_dim."""
        return self.fc(x)
class TemporalHeteroModel(nn.Module):
    """Stub temporal-heterogeneity branch.

    Projects each view to a scalar and returns the MSE between the two
    projections. ``batch_size``/``num_nodes``/``device`` are accepted only
    for interface compatibility with the full ST-SSL implementation.
    """

    def __init__(self, d_model, batch_size, num_nodes, device):
        super(TemporalHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return MSE between the 1-d projections of the two views."""
        p1 = self.fc(z1)
        p2 = self.fc(z2)
        return F.mse_loss(p1, p2)
class SpatialHeteroModel(nn.Module):
    """Stub spatial-heterogeneity branch.

    Projects each view to a scalar and returns the MSE between the two
    projections. ``nmb_prototype``/``batch_size``/``shm_temp`` are accepted
    only for interface compatibility with the full ST-SSL implementation.
    """

    def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
        super(SpatialHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return MSE between the 1-d projections of the two views."""
        p1 = self.fc(z1)
        p2 = self.fc(z2)
        return F.mse_loss(p1, p2)
class STSSL(nn.Module):
    """Simplified ST-SSL model.

    Combines a shared spatio-temporal encoder with three heads: a linear
    traffic-prediction head (``mlp``) and two contrastive heads modelling
    temporal (``thm``) and spatial (``shm``) heterogeneity. ``forward``
    produces representations of the raw view and of an augmented view.
    """

    def __init__(self, args):
        super(STSSL, self).__init__()
        d_model = args['d_model']
        block_spec = [[2, int(d_model // 2), d_model],
                      [d_model, int(d_model // 2), d_model]]
        # Spatio-temporal encoder shared by both views.
        self.encoder = STEncoder(Kt=3, Ks=3, blocks=block_spec,
                                 input_length=args['input_length'],
                                 num_nodes=args['num_nodes'],
                                 droprate=args['dropout'])
        # Traffic-flow prediction branch.
        self.mlp = MLP(d_model, args['d_output'])
        # Temporal-heterogeneity modelling branch.
        self.thm = TemporalHeteroModel(d_model, args['batch_size'],
                                       args['num_nodes'], args['device'])
        # Spatial-heterogeneity modelling branch.
        self.shm = SpatialHeteroModel(d_model, args['nmb_prototype'],
                                      args['batch_size'], args['shm_temp'])
        self.mae = masked_mae_loss(mask_value=5.0)
        self.args = args

    def forward(self, view1, graph):
        """Encode the raw view and an augmented second view.

        :param view1: traffic tensor, shape (n, l, v, c)
        :param graph: adjacency matrix, shape (v, v)
        :return: (repr1, repr2) encoder outputs for the two views
        """
        repr1 = self.encoder(view1, graph)
        graph2 = aug_topology(self.fetch_spatial_sim(), graph,
                              percent=self.args['percent'] * 2)
        view2 = aug_traffic(self.fetch_temporal_sim(), view1,
                            percent=self.args['percent'])
        repr2 = self.encoder(view2, graph2)
        return repr1, repr2

    def fetch_spatial_sim(self):
        """
        Fetch the region similarity matrix generated by region embedding.
        Note this can be called only when spatial_sim is True.
        :return sim_mx: tensor, similarity matrix, (v, v)
        """
        return self.encoder.s_sim_mx.cpu()

    def fetch_temporal_sim(self):
        """Fetch the temporal similarity matrix, shape (l, l), on CPU."""
        return self.encoder.t_sim_mx.cpu()

    def predict(self, z1, z2):
        '''Predicting future traffic flow.
        :param z1, z2 (tensor): shape nvc
        :return: nlvc, l=1, c=2
        '''
        # Simplified head: prediction is read from the first view only.
        return self.mlp(z1)

    def loss(self, z1, z2, y_true, scaler, loss_weights):
        """Weighted sum of prediction, temporal and spatial losses.

        :return: (total, sep) where ``sep`` is the list of unweighted
            scalar values [pred, temporal, spatial].
        """
        terms = [self.pred_loss(z1, z2, y_true, scaler),
                 self.temporal_loss(z1, z2),
                 self.spatial_loss(z1, z2)]
        sep_loss = [t.item() for t in terms]
        total = loss_weights[0] * terms[0]
        total = total + loss_weights[1] * terms[1]
        total = total + loss_weights[2] * terms[2]
        return total, sep_loss

    def pred_loss(self, z1, z2, y_true, scaler):
        """Masked MAE over the two traffic channels, balanced by ``yita``."""
        y_pred = scaler.inverse_transform(self.predict(z1, z2))
        y_true = scaler.inverse_transform(y_true)
        yita = self.args['yita']
        return (yita * self.mae(y_pred[..., 0], y_true[..., 0])
                + (1 - yita) * self.mae(y_pred[..., 1], y_true[..., 1]))

    def temporal_loss(self, z1, z2):
        """Temporal-heterogeneity contrastive loss (delegates to ``thm``)."""
        return self.thm(z1, z2)

    def spatial_loss(self, z1, z2):
        """Spatial-heterogeneity contrastive loss (delegates to ``shm``)."""
        return self.shm(z1, z2)