TrafficWheel/utils/loss_function.py

67 lines
2.1 KiB
Python
Executable File

def masked_mae_loss(scaler, mask_value):
    """Build a masked-MAE loss callable bound to ``scaler`` and ``mask_value``.

    Args:
        scaler: Object exposing ``inverse_transform``; when falsy, predictions
            are used unchanged.
        mask_value: Threshold forwarded to ``mae_torch``; ``None`` disables
            masking there.

    Returns:
        A function ``loss(preds, labels)`` that returns the value of
        ``mae_torch`` for the (optionally de-normalized) predictions.
    """

    def loss(preds, labels):
        # Only the predictions are de-normalized here. Per the original
        # author's note, labels are kept in their original scale by the data
        # pipeline (not verifiable from this file).
        restored = scaler.inverse_transform(preds) if scaler else preds
        return mae_torch(pred=restored, true=labels, mask_value=mask_value)

    return loss
def get_loss_function(args, scaler):
    """Return the loss selected by ``args["loss_func"]``.

    Args:
        args: Mapping that provides ``"loss_func"`` and, for the torch module
            losses, ``"device"``.
        scaler: Passed to ``masked_mae_loss`` for the ``"mask_mae"`` option;
            unused by the other options.

    Returns:
        For ``"mask_mae"`` a plain callable ``loss(preds, labels)``; for
        ``"mae"``, ``"mse"`` and ``"Huber"`` the matching ``torch.nn`` loss
        module moved to ``args["device"]``.

    Raises:
        ValueError: If ``args["loss_func"]`` is not one of the names above.
    """
    loss_name = args["loss_func"]

    if loss_name == "mask_mae":
        # A function closure has no ``.to``; masking is disabled by default
        # by passing ``mask_value=None``.
        return masked_mae_loss(scaler, mask_value=None)
    elif loss_name == "mae":
        return torch.nn.L1Loss().to(args["device"])
    elif loss_name == "mse":
        return torch.nn.MSELoss().to(args["device"])
    elif loss_name == "Huber":
        return torch.nn.HuberLoss().to(args["device"])
    else:
        # ``args`` is accessed as a mapping throughout; attribute access here
        # would raise AttributeError on a dict and hide the real problem.
        raise ValueError("Unsupported loss function: {}".format(loss_name))
import torch
def mae_torch(pred, true, mask_value=None):
    """Return the mean absolute error between ``pred`` and ``true``.

    When ``mask_value`` is not ``None``, only positions where
    ``true > mask_value`` contribute to the mean.
    """
    if mask_value is not None:
        # Build the mask from the untouched targets before either tensor is
        # reduced to its selected elements.
        keep = true.gt(mask_value)
        pred, true = pred.masked_select(keep), true.masked_select(keep)
    return (true - pred).abs().mean()
def rmse_torch(pred, true, mask_value=None):
    """Return the root-mean-square error between ``pred`` and ``true``.

    When ``mask_value`` is not ``None``, only positions where
    ``true > mask_value`` are included.
    """
    if mask_value is not None:
        selected = true.gt(mask_value)
        pred = pred.masked_select(selected)
        true = true.masked_select(selected)
    squared_error = (pred - true) ** 2
    return squared_error.mean().sqrt()
def mape_torch(pred, true, mask_value=None):
    """Return the mean absolute percentage error of ``pred`` against ``true``.

    The denominator is ``true + 0.001``. When ``mask_value`` is not ``None``,
    only positions where ``true > mask_value`` are included.
    """
    if mask_value is not None:
        valid = true.gt(mask_value)
        pred, true = pred.masked_select(valid), true.masked_select(valid)
    relative_error = (true - pred) / (true + 0.001)
    return relative_error.abs().mean()
def all_metrics(pred, true, mask1, mask2):
    """Compute MAE, RMSE and MAPE for ``pred`` against ``true``.

    Args:
        pred: Predictions passed to each metric function.
        true: Targets passed to each metric function.
        mask1: Mask threshold for MAE and RMSE; the string ``"None"`` is
            treated as ``None``.
        mask2: Mask threshold for MAPE; the string ``"None"`` is treated as
            ``None``.

    Returns:
        A tuple ``(mae, rmse, mape)``.
    """
    # The literal string "None" is mapped to a real ``None`` so it disables
    # masking in the metric functions.
    error_mask = None if mask1 == "None" else mask1
    percentage_mask = None if mask2 == "None" else mask2
    return (
        mae_torch(pred, true, error_mask),
        rmse_torch(pred, true, error_mask),
        mape_torch(pred, true, percentage_mask),
    )
if __name__ == "__main__":
    # Small smoke run: print (mae, rmse, mape) for a fixed pair of tensors
    # with masking disabled for every metric.
    pred, true = torch.Tensor([1, 2, 3, 4]), torch.Tensor([2, 1, 4, 5])
    print(all_metrics(pred=pred, true=true, mask1=None, mask2=None))