# TrafficWheel/model/STEP/step_loss.py
#
# Source listing metadata: 32 lines, 880 B, Python.

import numpy as np
from torch import nn
from lib.loss_function import mae_torch
def step_loss(prediction, real_value, theta, priori_adj, gsl_coefficient, null_val=np.nan):
    """Compute the STEP training loss: prediction MAE plus a weighted graph-learning BCE term.

    Args:
        prediction: Model predictions; forwarded to ``mae_torch`` as ``pred``.
        real_value: Ground-truth values; forwarded to ``mae_torch`` as ``true``.
        theta: Bernoulli distribution parameters of the learned graph, a 3-D
            tensor of shape ``(B, N, N)``. Values are expected to lie in
            ``[0, 1]``, as required by binary cross-entropy.
        priori_adj: Prior adjacency matrix used as the BCE target. It must hold
            ``B * N * N`` elements (e.g. shape ``(B, N, N)``); it is reshaped to
            ``(B, N * N)`` and cast to ``theta``'s dtype.
        gsl_coefficient: Weight applied to the graph-structure-learning loss.
        null_val: Marker for missing entries in ``real_value``; forwarded to
            ``mae_torch`` as ``mask_value``.

    Returns:
        ``loss_pred + loss_graph * gsl_coefficient``, where ``loss_pred`` is the
        value returned by ``mae_torch`` and ``loss_graph`` is the mean BCE
        between ``theta`` and ``priori_adj``.

    Raises:
        ValueError: If ``theta`` is not 3-D or its last two dimensions differ.
    """
    # --- graph structure learning loss ---
    batch_size, num_nodes, num_nodes_cols = theta.shape
    if num_nodes != num_nodes_cols:
        # The original `B, N, N = theta.shape` silently kept only the last
        # dimension here; fail loudly with a clear message instead.
        raise ValueError(
            f"theta must have shape (B, N, N), got {tuple(theta.shape)}"
        )
    flat_size = num_nodes * num_nodes
    # reshape (not view) so non-contiguous inputs, e.g. an expanded or
    # transposed prior adjacency, are accepted; it returns a view whenever
    # view() would have succeeded.
    theta_flat = theta.reshape(batch_size, flat_size)
    # BCE requires the target to have the same floating dtype as the input.
    target_flat = priori_adj.reshape(batch_size, flat_size).to(theta_flat.dtype)
    # Functional form of nn.BCELoss() with its default reduction='mean'.
    loss_graph = nn.functional.binary_cross_entropy(theta_flat, target_flat)

    # --- prediction loss ---
    loss_pred = mae_torch(pred=prediction, true=real_value, mask_value=null_val)

    # --- final loss ---
    return loss_pred + loss_graph * gsl_coefficient