import numpy as np
from torch import nn

from lib.loss_function import mae_torch


def step_loss(prediction, real_value, theta, priori_adj, gsl_coefficient, null_val=np.nan):
    """Loss function of the STEP model.

    The total loss is the prediction loss plus a weighted graph-structure
    learning (GSL) loss:

        loss = mae_torch(prediction, real_value) + gsl_coefficient * BCE(theta, priori_adj)

    Args:
        prediction: Model predictions.
        real_value: Ground-truth values, same shape as ``prediction``.
        theta: Parameters of the learned Bernoulli graph distribution, a
            tensor of shape (B, N, N). BCE requires values in [0, 1], so
            these are presumably probabilities (TODO confirm with the model).
        priori_adj: Prior adjacency used as the BCE target. It must contain
            B * N * N elements (e.g. shape (B, N, N)).
        gsl_coefficient: Weight applied to the graph-structure learning loss.
        null_val: Value passed to ``mae_torch`` as ``mask_value``. Presumably
            entries equal to it are masked out of the MAE (TODO confirm).

    Returns:
        Scalar tensor holding the combined loss.

    Raises:
        ValueError: If ``theta`` is not of shape (B, N, N).
    """
    # --- graph structure learning loss ---
    batch_size, rows, cols = theta.shape
    if rows != cols:
        raise ValueError(
            f"theta must have shape (B, N, N), got {tuple(theta.shape)}"
        )
    flat_size = rows * cols
    # reshape (not view) so non-contiguous tensors, e.g. a transposed or
    # expanded prior adjacency, are accepted as well.
    theta_flat = theta.reshape(batch_size, flat_size)
    # BCE requires the target to share the input's dtype; cast so integer,
    # bool, or double adjacency matrices do not raise.
    target = priori_adj.reshape(batch_size, flat_size).to(dtype=theta_flat.dtype)
    # Same default 'mean' reduction as nn.BCELoss(), without building a module per call.
    loss_graph = nn.functional.binary_cross_entropy(theta_flat, target)

    # --- prediction loss ---
    loss_pred = mae_torch(pred=prediction, true=real_value, mask_value=null_val)

    # --- final loss ---
    return loss_pred + loss_graph * gsl_coefficient