Fix some bugs can not run
This commit is contained in:
parent
9a3316d5d4
commit
d236ba5eae
|
|
@ -0,0 +1,46 @@
|
|||
from federatedscope.register import register_criterion
|
||||
|
||||
def call_my_criterion(type, device):
|
||||
try:
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
except ImportError:
|
||||
nn = None
|
||||
criterion = None
|
||||
|
||||
class RMSELoss(nn.Module):
|
||||
def __init__(self):
|
||||
super(RMSELoss, self).__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
return torch.sqrt(self.mse(y_pred, y_true))
|
||||
|
||||
class MAPELoss(nn.Module):
|
||||
def __init__(self, epsilon=1e-10):
|
||||
super(MAPELoss, self).__init__()
|
||||
self.epsilon = epsilon
|
||||
|
||||
def forward(self, y_pred, y_true):
|
||||
mask_value = 0.1
|
||||
if mask_value is not None:
|
||||
mask = torch.gt(y_true, mask_value)
|
||||
pred = torch.masked_select(y_pred, mask)
|
||||
true = torch.masked_select(y_true, mask)
|
||||
return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001)))) * 100
|
||||
|
||||
|
||||
if type == 'RMSE':
|
||||
if nn is not None:
|
||||
criterion = RMSELoss().to(device)
|
||||
elif type == 'MAPE':
|
||||
if nn is not None:
|
||||
criterion = MAPELoss().to(device)
|
||||
else:
|
||||
criterion = None
|
||||
|
||||
return criterion
|
||||
|
||||
# Register the custom RMSE and MAPE criterion
|
||||
register_criterion('RMSE', call_my_criterion)
|
||||
register_criterion('MAPE', call_my_criterion)
|
||||
|
|
@ -62,6 +62,10 @@ def get_dataloader(dataset, config, split='train'):
|
|||
elif config.dataloader.type == 'mf':
|
||||
from federatedscope.mf.dataloader import MFDataLoader
|
||||
loader_cls = MFDataLoader
|
||||
elif config.dataloader.type == 'trafficflow':
|
||||
# 待定
|
||||
from torch.utils.data import DataLoader
|
||||
loader_cls = DataLoader
|
||||
else:
|
||||
raise ValueError(f'data.loader.type {config.data.loader.type} '
|
||||
f'not found!')
|
||||
|
|
|
|||
Loading…
Reference in New Issue