32 lines
880 B
Python
32 lines
880 B
Python
import numpy as np
|
|
from torch import nn
|
|
from lib.loss_function import mae_torch
|
|
|
|
def step_loss(prediction, real_value, theta, priori_adj, gsl_coefficient, null_val=np.nan):
    """Compute the training loss of the STEP model.

    The total loss is the prediction loss plus a weighted graph structure
    learning (GSL) loss:

        loss = loss_pred + gsl_coefficient * loss_graph

    where ``loss_pred`` is computed by ``mae_torch`` and ``loss_graph`` is the
    mean binary cross-entropy between the learned edge parameters ``theta``
    and the prior adjacency ``priori_adj``.

    Args:
        prediction: predicted values, passed to ``mae_torch`` as ``pred``.
        real_value: ground-truth values, passed to ``mae_torch`` as ``true``.
        theta: Bernoulli distribution parameters of the learned graph, a 3-D
            tensor (batch, rows, cols). Values must lie in [0, 1], as required
            by binary cross-entropy.
        priori_adj: prior adjacency matrix used as the BCE target. It must
            hold the same number of elements as ``theta`` (e.g. the same
            (batch, rows, cols) shape), with values in [0, 1].
        gsl_coefficient: weight of the graph structure learning loss.
        null_val: passed to ``mae_torch`` as ``mask_value``.

    Returns:
        The total loss (``loss_pred + loss_graph * gsl_coefficient``).

    Raises:
        ValueError: if ``theta`` is not 3-dimensional.
    """
    # graph structure learning loss
    batch_size, num_rows, num_cols = theta.shape
    # reshape (not view) so non-contiguous inputs, e.g. transposed matrices,
    # are accepted instead of raising
    theta_flat = theta.reshape(batch_size, num_rows * num_cols)
    # BCE requires the target to share the input's dtype; casting lets a 0/1
    # integer or bool prior adjacency be used directly
    target = priori_adj.reshape(batch_size, num_rows * num_cols).to(theta_flat.dtype)
    # functional form: same mean reduction as nn.BCELoss() without building a
    # module object on every call
    loss_graph = nn.functional.binary_cross_entropy(theta_flat, target)

    # prediction loss
    loss_pred = mae_torch(pred=prediction, true=real_value, mask_value=null_val)

    # final loss
    return loss_pred + loss_graph * gsl_coefficient
|