import torch
import torch.nn as nn
import torch.nn.functional as F
# Simplified masked-MAE loss factory.
def masked_mae_loss(mask_value=5.0):
    """Build an L1 loss that ignores positions where target == mask_value.

    :param mask_value: target entries exactly equal to this sentinel are
        excluded from the loss.
    :return: callable loss_fn(pred, target) -> scalar tensor.
    """

    def loss_fn(pred, target):
        # 1.0 where the target is valid, 0.0 where it equals the sentinel.
        valid = (target != mask_value).float()
        # Summed absolute error over masked inputs, averaged over the number
        # of valid entries; epsilon guards an all-masked target.
        total = F.l1_loss(pred * valid, target * valid, reduction="sum")
        return total / (valid.sum() + 1e-8)

    return loss_fn


# Simplified data-augmentation placeholders.
def aug_topology(sim_mx, graph, percent=0.1):
    """Placeholder graph augmentation: returns the graph unchanged.

    sim_mx and percent are accepted for interface compatibility only.
    """
    return graph


def aug_traffic(sim_mx, data, percent=0.1):
    """Placeholder traffic augmentation: returns the data unchanged.

    sim_mx and percent are accepted for interface compatibility only.
    """
    return data


class STEncoder(nn.Module):
    """Simplified spatio-temporal encoder: two 2-D convolutions over the
    (num_nodes, seq_len) plane with ReLU activations and dropout between.

    :param Kt: kernel size along the node axis of the conv kernel.
    :param Ks: kernel size along the sequence axis of the conv kernel.
    :param blocks: channel spec; blocks[0] = [c_in, c_hidden, c_out].
    :param input_length: number of time steps in the input sequence.
    :param num_nodes: number of graph nodes.
    :param droprate: dropout probability applied between the convolutions.
    """

    def __init__(self, Kt, Ks, blocks, input_length, num_nodes, droprate):
        super(STEncoder, self).__init__()
        self.num_nodes = num_nodes
        self.input_length = input_length

        # Simplified encoder. Record the configured input width so forward()
        # can adapt inputs to it instead of hard-coding a channel count.
        self.in_channels = blocks[0][0]
        self.conv1 = nn.Conv2d(
            blocks[0][0], blocks[0][1], kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2)
        )
        self.conv2 = nn.Conv2d(
            blocks[0][1], blocks[0][2], kernel_size=(Kt, Ks), padding=(Kt // 2, Ks // 2)
        )
        self.dropout = nn.Dropout(droprate)

        # Placeholder similarity matrices consumed by fetch_*_sim() in STSSL.
        self.s_sim_mx = torch.randn(num_nodes, num_nodes)
        self.t_sim_mx = torch.randn(input_length, input_length)

    def forward(self, x, graph):
        """Encode x. `graph` is accepted for interface compatibility but unused.

        :param x: tensor (batch_size, num_nodes, seq_len, features)
        :return: tensor (batch_size, num_nodes, seq_len, blocks[0][2])
        """
        batch_size, num_nodes, seq_len, features = x.shape

        # Move features onto the channel axis for Conv2d.
        x = x.permute(0, 3, 1, 2)  # (batch_size, features, num_nodes, seq_len)

        # Adapt the channel count to what conv1 expects. Previously this was
        # hard-coded to 2, silently diverging from blocks[0][0]; it is now
        # derived from the configured block spec.
        if x.shape[1] != self.in_channels:
            if x.shape[1] == 1:
                # Replicate a single channel up to the expected width.
                x = x.repeat(1, self.in_channels, 1, 1)
            else:
                # Keep only the first in_channels channels.
                x = x[:, : self.in_channels, :, :]

        # Two convolutions with ReLU, dropout in between.
        x = F.relu(self.conv1(x))
        x = self.dropout(x)
        x = F.relu(self.conv2(x))

        # Restore the original axis order.
        x = x.permute(0, 2, 3, 1)  # (batch_size, num_nodes, seq_len, features)

        return x


class MLP(nn.Module):
    """Single linear layer mapping input_dim features to output_dim."""

    def __init__(self, input_dim, output_dim):
        super(MLP, self).__init__()
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Project the last dimension of x through the linear layer."""
        out = self.fc(x)
        return out


class TemporalHeteroModel(nn.Module):
    """Simplified temporal-heterogeneity branch.

    batch_size, num_nodes and device are accepted for interface
    compatibility with the full model but are not used here.
    """

    def __init__(self, d_model, batch_size, num_nodes, device):
        super(TemporalHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the MSE between the linear projections of the two views."""
        proj1 = self.fc(z1)
        proj2 = self.fc(z2)
        return F.mse_loss(proj1, proj2)


class SpatialHeteroModel(nn.Module):
    """Simplified spatial-heterogeneity branch.

    nmb_prototype, batch_size and shm_temp are accepted for interface
    compatibility with the full model but are not used here.
    """

    def __init__(self, d_model, nmb_prototype, batch_size, shm_temp):
        super(SpatialHeteroModel, self).__init__()
        self.fc = nn.Linear(d_model, 1)

    def forward(self, z1, z2):
        """Return the MSE between the linear projections of the two views."""
        proj1 = self.fc(z1)
        proj2 = self.fc(z2)
        return F.mse_loss(proj1, proj2)


class STSSL(nn.Module):
    """Spatio-Temporal Self-Supervised Learning model.

    Wraps a spatio-temporal encoder with three branches: traffic-flow
    prediction (MLP), temporal-heterogeneity modeling, and
    spatial-heterogeneity modeling.
    """

    def __init__(self, args):
        super(STSSL, self).__init__()
        d_model = args["d_model"]
        half = int(d_model // 2)

        # spatial temporal encoder
        self.encoder = STEncoder(
            Kt=3,
            Ks=3,
            blocks=[
                [2, half, d_model],
                [d_model, half, d_model],
            ],
            input_length=args["input_length"],
            num_nodes=args["num_nodes"],
            droprate=args["dropout"],
        )

        # traffic flow prediction branch
        self.mlp = MLP(d_model, args["d_output"])
        # temporal heterogeneity modeling branch
        self.thm = TemporalHeteroModel(
            d_model, args["batch_size"], args["num_nodes"], args["device"]
        )
        # spatial heterogeneity modeling branch
        self.shm = SpatialHeteroModel(
            d_model, args["nmb_prototype"], args["batch_size"], args["shm_temp"]
        )
        self.mae = masked_mae_loss(mask_value=5.0)
        self.args = args

    def forward(self, view1, graph):
        """Encode the input view plus an augmented second view.

        :param view1: traffic tensor (n,l,v,c); graph: adjacency (v,v).
        :return: (repr1, repr2) encoder outputs for the two views.
        """
        repr1 = self.encoder(view1, graph)  # view1: n,l,v,c; graph: v,v

        # Second view: augmented topology from the spatial similarity.
        s_sim_mx = self.fetch_spatial_sim()
        graph2 = aug_topology(s_sim_mx, graph, percent=self.args["percent"] * 2)

        # ... and augmented traffic from the temporal similarity.
        t_sim_mx = self.fetch_temporal_sim()
        view2 = aug_traffic(t_sim_mx, view1, percent=self.args["percent"])

        repr2 = self.encoder(view2, graph2)
        return repr1, repr2

    def fetch_spatial_sim(self):
        """
        Fetch the region similarity matrix generated by region embedding.
        Note this can be called only when spatial_sim is True.
        :return sim_mx: tensor, similarity matrix, (v, v)
        """
        return self.encoder.s_sim_mx.cpu()

    def fetch_temporal_sim(self):
        """Fetch the temporal similarity matrix, (l, l), on CPU."""
        return self.encoder.t_sim_mx.cpu()

    def predict(self, z1, z2):
        """Predicting future traffic flow.
        z2 is accepted for interface compatibility but currently unused.
        :param z1, z2 (tensor): shape nvc
        :return: nlvc, l=1, c=2
        """
        return self.mlp(z1)

    def loss(self, z1, z2, y_true, scaler, loss_weights):
        """Weighted sum of prediction, temporal and spatial losses.

        :return: (total_loss, [l1, l2, l3] detached component values)
        """
        pred = self.pred_loss(z1, z2, y_true, scaler)
        temporal = self.temporal_loss(z1, z2)
        spatial = self.spatial_loss(z1, z2)

        sep_loss = [pred.item(), temporal.item(), spatial.item()]
        total = (
            loss_weights[0] * pred
            + loss_weights[1] * temporal
            + loss_weights[2] * spatial
        )
        return total, sep_loss

    def pred_loss(self, z1, z2, y_true, scaler):
        """Masked-MAE prediction loss, blending the two channels by yita.

        Assumes y_pred/y_true carry at least two channels in the last
        dimension (channel 0 weighted by yita, channel 1 by 1 - yita).
        """
        y_pred = scaler.inverse_transform(self.predict(z1, z2))
        y_true = scaler.inverse_transform(y_true)

        yita = self.args["yita"]
        loss_c0 = self.mae(y_pred[..., 0], y_true[..., 0])
        loss_c1 = self.mae(y_pred[..., 1], y_true[..., 1])
        return yita * loss_c0 + (1 - yita) * loss_c1

    def temporal_loss(self, z1, z2):
        """Temporal-heterogeneity loss between the two view representations."""
        return self.thm(z1, z2)

    def spatial_loss(self, z1, z2):
        """Spatial-heterogeneity loss between the two view representations."""
        return self.shm(z1, z2)