diff --git a/federatedscope/contrib/loss/RMSE.py b/federatedscope/contrib/loss/RMSE.py new file mode 100644 index 0000000..6ad80ff --- /dev/null +++ b/federatedscope/contrib/loss/RMSE.py @@ -0,0 +1,46 @@ +from federatedscope.register import register_criterion + +def call_my_criterion(type, device): + try: + import torch + import torch.nn as nn + except ImportError: + nn = None + criterion = None + + class RMSELoss(nn.Module): + def __init__(self): + super(RMSELoss, self).__init__() + self.mse = nn.MSELoss() + + def forward(self, y_pred, y_true): + return torch.sqrt(self.mse(y_pred, y_true)) + + class MAPELoss(nn.Module): + def __init__(self, epsilon=1e-10): + super(MAPELoss, self).__init__() + self.epsilon = epsilon + + def forward(self, y_pred, y_true): + mask_value = 0.1 + if mask_value is not None: + mask = torch.gt(y_true, mask_value) + pred = torch.masked_select(y_pred, mask) + true = torch.masked_select(y_true, mask) + return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001)))) * 100 + + + if type == 'RMSE': + if nn is not None: + criterion = RMSELoss().to(device) + elif type == 'MAPE': + if nn is not None: + criterion = MAPELoss().to(device) + else: + criterion = None + + return criterion + +# Register the custom RMSE and MAPE criterion +register_criterion('RMSE', call_my_criterion) +register_criterion('MAPE', call_my_criterion) diff --git a/federatedscope/core/auxiliaries/dataloader_builder.py b/federatedscope/core/auxiliaries/dataloader_builder.py index 4b95741..5eea675 100644 --- a/federatedscope/core/auxiliaries/dataloader_builder.py +++ b/federatedscope/core/auxiliaries/dataloader_builder.py @@ -62,6 +62,10 @@ def get_dataloader(dataset, config, split='train'): elif config.dataloader.type == 'mf': from federatedscope.mf.dataloader import MFDataLoader loader_cls = MFDataLoader + elif config.dataloader.type == 'trafficflow': + # 待定 + from torch.utils.data import DataLoader + loader_cls = DataLoader else: raise ValueError(f'data.loader.type {config.data.loader.type} ' f'not found!')