50 lines
1.4 KiB
Python
50 lines
1.4 KiB
Python
from federatedscope.register import register_criterion
|
|
|
|
"""
|
|
Custom RMSE and MAPE loss criteria for traffic flow prediction (TFP).
|
|
"""
|
|
def TFP_criterion(type, device):
    """Build an RMSE or MAPE loss criterion for traffic flow prediction.

    This builder is registered under both the ``'RMSE'`` and ``'MAPE'``
    names and dispatches on ``type`` to decide which loss to construct.

    Args:
        type: Requested criterion name; ``'RMSE'`` and ``'MAPE'`` are
            handled here. (The name shadows the builtin but is kept for
            caller compatibility.)
        device: Device the loss module is moved to via ``nn.Module.to``.

    Returns:
        An ``nn.Module`` loss on ``device``, or ``None`` when ``type`` is not
        handled here or PyTorch cannot be imported.
    """
    try:
        import torch
        import torch.nn as nn
    except ImportError:
        # PyTorch is optional: without it this builder provides nothing.
        # Returning here also avoids subclassing ``None.Module`` below,
        # which would raise AttributeError instead of yielding None.
        return None

    class RMSELoss(nn.Module):
        """Root-mean-squared error: ``sqrt(MSE(y_pred, y_true))``."""

        def __init__(self):
            super(RMSELoss, self).__init__()
            self.mse = nn.MSELoss()

        def forward(self, y_pred, y_true):
            return torch.sqrt(self.mse(y_pred, y_true))

    class MAPELoss(nn.Module):
        """Masked mean absolute percentage error, expressed in percent.

        Only positions where ``y_true > mask_value`` contribute, so that
        near-zero targets do not blow up the ratio; pass ``mask_value=None``
        to disable masking. ``epsilon`` is added to the denominator for
        numerical safety (its default equals the constant previously
        hard-coded here).
        """

        def __init__(self, mask_value=0.1, epsilon=1e-3):
            super(MAPELoss, self).__init__()
            self.mask_value = mask_value
            self.epsilon = epsilon

        def forward(self, y_pred, y_true):
            if self.mask_value is not None:
                mask = torch.gt(y_true, self.mask_value)
                y_pred = torch.masked_select(y_pred, mask)
                y_true = torch.masked_select(y_true, mask)
            if y_true.numel() == 0:
                # Every target was masked out: the mean of an empty tensor
                # is NaN, so report zero loss instead. Multiplying y_pred's
                # sum keeps the result attached to the autograd graph.
                return y_pred.sum() * 0.0
            ratio = torch.div(y_true - y_pred, y_true + self.epsilon)
            return torch.mean(torch.abs(ratio)) * 100

    if type == 'RMSE':
        return RMSELoss().to(device)
    if type == 'MAPE':
        return MAPELoss().to(device)
    # Not one of ours: let other registered criterion builders handle it.
    return None
|
|
|
|
# Expose TFP_criterion under both metric names; the builder itself selects
# the matching loss from the requested criterion type.
for _tfp_name in ('RMSE', 'MAPE'):
    register_criterion(_tfp_name, TFP_criterion)
del _tfp_name
|