DCRNN/model/pytorch/loss.py

10 lines
193 B
Python

import torch
def masked_mae_loss(y_pred, y_true):
    """Mean absolute error that ignores positions where the target is zero.

    Entries with ``y_true == 0`` get zero weight. The remaining entries are
    re-weighted by ``1 / mask.mean()``, so ``loss.mean()`` over the full
    tensor equals the mean absolute error over the non-zero targets only.

    Args:
        y_pred: Predicted values (tensor broadcast-compatible with ``y_true``).
        y_true: Target values; zeros are treated as missing and masked out.

    Returns:
        A scalar tensor with the masked MAE. If every target is zero, the
        result is ``0`` rather than NaN.
    """
    mask = (y_true != 0).float()
    mask = mask / mask.mean()
    # When every target is zero, mask.mean() is 0 and the division above yields
    # NaN everywhere. Treat that as "no valid entries" (zero weight). The mask
    # carries no gradient, so cleaning it here keeps NaNs out of backward too.
    mask = torch.where(torch.isnan(mask), torch.zeros_like(mask), mask)
    loss = torch.abs(y_pred - y_true) * mask
    # A NaN target still passes the != 0 test and would poison the whole batch
    # mean. Zero such entries, mirroring the upstream DCRNN loss.
    loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)
    return loss.mean()