import torch


def masked_mae_loss(y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    """Return the mean absolute error over entries whose target is non-zero.

    Entries where ``y_true`` is exactly 0 are given zero weight. The
    remaining entries are re-weighted by ``1 / mean(mask)`` so that the
    final ``mean()`` over all elements equals the average absolute error
    over the non-zero targets only.

    Args:
        y_pred: Predicted values. Presumably the same shape as ``y_true``
            (confirm against callers; nothing here enforces it).
        y_true: Target values; a value of 0 marks an entry to ignore.

    Returns:
        A zero-dimensional tensor holding the masked MAE. When every target
        is 0 the result is 0 (with zero gradients) rather than NaN.
    """
    mask = (y_true != 0).float()

    # Guard the normalisation against an all-zero mask: 0 / 0 would turn
    # every weight (and therefore the loss) into NaN. Any non-empty mask has
    # mean >= 1 / numel, far above the smallest positive normal float, so
    # the clamp leaves the ordinary case numerically unchanged.
    smallest_positive = torch.finfo(mask.dtype).tiny
    weights = mask / mask.mean().clamp(min=smallest_positive)

    abs_err = torch.abs(y_pred - y_true)
    return (abs_err * weights).mean()