"""Simplified ST-SSL (spatio-temporal self-supervised learning) model components."""
import torch
import torch.nn as nn
import torch.nn.functional as F
# Simplified masked-MAE loss factory.
def masked_mae_loss(mask_value=5.0):
    """Build an L1 loss that ignores target entries equal to ``mask_value``.

    :param mask_value: sentinel value marking target entries to exclude.
    :return: callable ``loss_fn(pred, target)`` returning a scalar tensor.
    """
    def loss_fn(pred, target):
        # 1.0 where the target is valid, 0.0 where it equals the sentinel.
        keep = target.ne(mask_value).float()
        total = F.l1_loss(pred * keep, target * keep, reduction='sum')
        # Normalize by the number of kept entries; epsilon guards an all-masked target.
        return total / (keep.sum() + 1e-8)
    return loss_fn

# Simplified data-augmentation placeholder.
def aug_topology(sim_mx, graph, percent=0.1):
    """Placeholder topology augmentation: returns ``graph`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility only.
    """
    return graph

def aug_traffic(sim_mx, data, percent=0.1):
    """Placeholder traffic augmentation: returns ``data`` unchanged.

    ``sim_mx`` and ``percent`` are accepted for interface compatibility only.
    """
    return data

class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder.

    Applies two Conv2d layers (with ReLU and dropout) over the
    (num_nodes, seq_len) grid of the input.

    :param Kt: temporal kernel size (odd values keep the sequence length).
    :param Ks: spatial kernel size (odd values keep the node dimension).
    :param blocks: channel spec; ``blocks[0]`` is ``[in_ch, hidden_ch, out_ch]``.
    :param input_length: number of input time steps.
    :param num_nodes: number of graph nodes.
    :param droprate: dropout probability applied between the two convolutions.
    """

    def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
        super(STEncoder, self).__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length
        # Generalization: the expected input channel count was hard-coded to 2
        # in forward(); it is now taken from blocks[0][0] (identical behavior
        # for the default spec, where blocks[0][0] == 2).
        self.in_channels = blocks[0][0]

        # Simplified encoder. NOTE: padding Kt//2 / Ks//2 preserves the
        # spatial size only for odd kernel sizes.
        self.conv1 = nn.Conv2d(blocks[0][0], blocks[0][1], kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2))
        self.conv2 = nn.Conv2d(blocks[0][1], blocks[0][2], kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2))
        self.dropout = nn.Dropout(droprate)

        # Placeholder similarity matrices consumed by STSSL.fetch_*_sim().
        # NOTE(review): plain tensors, not registered buffers, so .to(device)
        # / state_dict will not include them — confirm this is intended.
        self.s_sim_mx = torch.randn(num_nodes, num_nodes)
        self.t_sim_mx = torch.randn(input_length, input_length)

    def forward(self, x, graph):
        """Encode a traffic tensor.

        :param x: tensor of shape (batch, num_nodes, seq_len, features).
        :param graph: adjacency matrix; unused by this simplified encoder.
        :return: tensor of shape (batch, num_nodes, seq_len, blocks[0][2]).
        """
        batch_size, num_nodes, seq_len, features = x.shape

        # (B, V, L, C) -> (B, C, V, L) so Conv2d sees channels first.
        x = x.permute(0, 3, 1, 2)

        # Adapt the feature dimension to the channel count the convs expect.
        if x.shape[1] != self.in_channels:
            if x.shape[1] == 1:
                # Broadcast a single channel up to the expected count.
                x = x.repeat(1, self.in_channels, 1, 1)
            else:
                # Keep only the leading expected channels.
                x = x[:, :self.in_channels, :, :]

        # Two-layer convolutional encoding with dropout in between.
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))

        # (B, C, V, L) -> (B, V, L, C) to restore the input layout.
        x = x.permute(0, 2, 3, 1)

        return x

class MLP(nn.Module):
    """Single linear projection used as the prediction head.

    :param input_dim: size of the last input dimension.
    :param output_dim: size of the projected output dimension.
    """

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project the last dimension of ``x`` to ``output_dim``."""
        projected = self.fc(x)
        return projected

class TemporalHeteroModel(nn.Module):
    """Simplified temporal-heterogeneity branch.

    Projects both representations through a shared linear head and scores
    their disagreement with MSE. ``batch_size``, ``num_nodes`` and ``device``
    are accepted for interface compatibility but unused in this stub.
    """

    def __init__(self, d_model, batch_size, num_nodes, device):
        super(TemporalHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the MSE between the two projected views (scalar tensor)."""
        proj1 = self.fc(z1)
        proj2 = self.fc(z2)
        return F.mse_loss(proj1, proj2)

class SpatialHeteroModel(nn.Module):
    """Simplified spatial-heterogeneity branch.

    Mirrors TemporalHeteroModel: a shared linear head plus MSE between the
    projections. ``nmb_prototype``, ``batch_size`` and ``shm_temp`` are
    accepted for interface compatibility but unused in this stub.
    """

    def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
        super(SpatialHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the MSE between the two projected views (scalar tensor)."""
        proj1 = self.fc(z1)
        proj2 = self.fc(z2)
        return F.mse_loss(proj1, proj2)

class STSSL(nn.Module):
    """Simplified ST-SSL model.

    Combines a spatio-temporal encoder with a prediction head plus temporal
    and spatial heterogeneity branches, and produces two encoded views
    (original + augmented) for self-supervised training.

    :param args: dict-like config with keys ``d_model``, ``input_length``,
        ``num_nodes``, ``dropout``, ``d_output``, ``batch_size``, ``device``,
        ``nmb_prototype``, ``shm_temp``, ``percent`` and ``yita``.
    """

    def __init__(self, args):
        super(STSSL, self).__init__()
        d_model = args['d_model']
        half = int(d_model // 2)
        # Spatio-temporal encoder; the first block consumes 2 input channels.
        self.encoder = STEncoder(
            Kt=3, Ks=3,
            blocks=[[2, half, d_model], [d_model, half, d_model]],
            input_length=args['input_length'],
            num_nodes=args['num_nodes'],
            droprate=args['dropout'],
        )

        # Traffic-flow prediction branch.
        self.mlp = MLP(d_model, args['d_output'])
        # Temporal heterogeneity modeling branch.
        self.thm = TemporalHeteroModel(d_model, args['batch_size'], args['num_nodes'], args['device'])
        # Spatial heterogeneity modeling branch.
        self.shm = SpatialHeteroModel(d_model, args['nmb_prototype'], args['batch_size'], args['shm_temp'])
        self.mae = masked_mae_loss(mask_value=5.0)
        self.args = args

    def forward(self, view1, graph):
        """Encode the input view and an augmented view.

        :param view1: traffic tensor (n, l, v, c).
        :param graph: adjacency matrix (v, v).
        :return: tuple ``(repr1, repr2)`` of encoder outputs for the
            original and augmented views.
        """
        repr1 = self.encoder(view1, graph)

        # Augment the topology using the spatial similarity matrix.
        graph2 = aug_topology(self.fetch_spatial_sim(), graph,
                              percent=self.args['percent'] * 2)
        # Augment the traffic view using the temporal similarity matrix.
        view2 = aug_traffic(self.fetch_temporal_sim(), view1,
                            percent=self.args['percent'])

        repr2 = self.encoder(view2, graph2)
        return repr1, repr2

    def fetch_spatial_sim(self):
        """
        Fetch the region similarity matrix generated by region embedding.
        Note this can be called only when spatial_sim is True.
        :return sim_mx: tensor, similarity matrix, (v, v)
        """
        return self.encoder.s_sim_mx.cpu()

    def fetch_temporal_sim(self):
        """Return the encoder's temporal similarity matrix on CPU."""
        return self.encoder.t_sim_mx.cpu()

    def predict(self, z1, z2):
        '''Predicting future traffic flow.
        :param z1, z2 (tensor): shape nvc (z2 is unused by this stub)
        :return: nlvc, l=1, c=2
        '''
        return self.mlp(z1)

    def loss(self, z1, z2, y_true, scaler, loss_weights):
        """Weighted sum of prediction, temporal and spatial losses.

        :param loss_weights: sequence of three weights, one per term.
        :return: ``(total_loss, sep_loss)`` where ``sep_loss`` holds the
            three unweighted term values as Python floats.
        """
        pred = self.pred_loss(z1, z2, y_true, scaler)
        temporal = self.temporal_loss(z1, z2)
        spatial = self.spatial_loss(z1, z2)
        sep_loss = [pred.item(), temporal.item(), spatial.item()]
        total = (loss_weights[0] * pred
                 + loss_weights[1] * temporal
                 + loss_weights[2] * spatial)
        return total, sep_loss

    def pred_loss(self, z1, z2, y_true, scaler):
        """Masked-MAE prediction loss over the two output channels.

        Channel 0 is weighted by ``yita`` and channel 1 by ``1 - yita``.
        """
        y_pred = scaler.inverse_transform(self.predict(z1, z2))
        y_true = scaler.inverse_transform(y_true)

        yita = self.args['yita']
        return (yita * self.mae(y_pred[..., 0], y_true[..., 0])
                + (1 - yita) * self.mae(y_pred[..., 1], y_true[..., 1]))

    def temporal_loss(self, z1, z2):
        """Temporal-heterogeneity loss from the THM branch."""
        return self.thm(z1, z2)

    def spatial_loss(self, z1, z2):
        """Spatial-heterogeneity loss from the SHM branch."""
        return self.shm(z1, z2)