56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class MaskedMAELoss(nn.Module):
    """Masked mean-absolute-error loss wrapped as an ``nn.Module``.

    When a truthy ``scaler`` is supplied, predictions and labels are both
    passed through ``scaler.inverse_transform`` before the error is computed.
    The error itself is delegated to ``mae_torch`` together with the stored
    ``mask_value``.

    Args:
        scaler: Object exposing ``inverse_transform``; a falsy value
            (e.g. ``None``) skips the inverse transform.
        mask_value: Threshold forwarded to ``mae_torch``; ``None`` disables
            masking there.
    """

    def __init__(self, scaler, mask_value):
        super().__init__()
        self.mask_value = mask_value
        self.scaler = scaler

    def forward(self, preds, labels):
        """Return the masked MAE between ``preds`` and ``labels``."""
        scaler = self.scaler
        if scaler:
            # Undo the scaling so the error is measured on inverse-transformed
            # values; predictions are transformed first, then labels.
            preds = scaler.inverse_transform(preds)
            labels = scaler.inverse_transform(labels)
        return mae_torch(pred=preds, true=labels, mask_value=self.mask_value)
|
|
|
|
def masked_mae_loss(scaler, mask_value):
    """Build a ``MaskedMAELoss``; function kept for backward compatibility."""
    return MaskedMAELoss(scaler=scaler, mask_value=mask_value)
|
|
|
|
def mae_torch(pred, true, mask_value=None):
    """Return the mean absolute error between ``pred`` and ``true``.

    Args:
        pred: Predicted values (tensor).
        true: Ground-truth values (tensor).
        mask_value: When not ``None``, only positions where
            ``true > mask_value`` contribute to the mean.

    Returns:
        A scalar tensor. If the mask selects no elements, the mean is taken
        over an empty tensor and the result is NaN.
    """
    if mask_value is not None:
        # Keep only the entries whose ground truth exceeds the threshold.
        keep = true.gt(mask_value)
        pred, true = pred.masked_select(keep), true.masked_select(keep)
    abs_err = (true - pred).abs()
    return abs_err.mean()
|
|
|
|
|
|
def rmse_torch(pred, true, mask_value=None):
    """Return the root-mean-square error between ``pred`` and ``true``.

    Args:
        pred: Predicted values (tensor).
        true: Ground-truth values (tensor).
        mask_value: When not ``None``, only positions where
            ``true > mask_value`` contribute to the mean.

    Returns:
        A scalar tensor. If the mask selects no elements, the result is NaN.
    """
    if mask_value is not None:
        # Keep only the entries whose ground truth exceeds the threshold.
        keep = true.gt(mask_value)
        pred, true = pred.masked_select(keep), true.masked_select(keep)
    squared_err = (pred - true) ** 2
    return squared_err.mean().sqrt()
|
|
|
|
|
|
def mape_torch(pred, true, mask_value=None):
    """Return the mean absolute percentage error as a fraction.

    The per-element error is ``|true - pred| / |true + 0.001|``; the result
    is not multiplied by 100.

    Args:
        pred: Predicted values (tensor).
        true: Ground-truth values (tensor).
        mask_value: When not ``None``, only positions where
            ``true > mask_value`` contribute to the mean.

    Returns:
        A scalar tensor. If the mask selects no elements, the result is NaN.
    """
    if mask_value is not None:
        # Keep only the entries whose ground truth exceeds the threshold.
        keep = true.gt(mask_value)
        pred, true = pred.masked_select(keep), true.masked_select(keep)
    # The 0.001 offset on the denominator presumably guards against division
    # by zero when the ground truth is 0.
    ratio = (true - pred) / (true + 0.001)
    return ratio.abs().mean()
|
|
|
|
|
|
def all_metrics(pred, true, mask1, mask2):
    """Compute MAE, RMSE and MAPE for one prediction/target pair.

    Args:
        pred: Predicted values (tensor).
        true: Ground-truth values (tensor).
        mask1: Mask threshold used for MAE and RMSE. The string ``'None'``
            is treated as ``None`` (presumably because the value can arrive
            as text from a config file).
        mask2: Mask threshold used for MAPE, with the same ``'None'``
            handling.

    Returns:
        Tuple ``(mae, rmse, mape)``.
    """
    # Normalise the textual 'None' into a real None before using the masks.
    mask1 = None if mask1 == 'None' else mask1
    mask2 = None if mask2 == 'None' else mask2
    return (
        mae_torch(pred, true, mask1),
        rmse_torch(pred, true, mask1),
        mape_torch(pred, true, mask2),
    )
|
|
|
|
|
|
if __name__ == '__main__':
    # Quick manual check: print all three metrics with masking disabled.
    sample_pred = torch.Tensor([1, 2, 3, 4])
    sample_true = torch.Tensor([2, 1, 4, 5])
    metrics = all_metrics(sample_pred, sample_true, None, None)
    print(metrics)