Fix some bugs can not run

This commit is contained in:
HengZhang 2024-11-21 13:14:48 +08:00
parent 9a3316d5d4
commit d236ba5eae
2 changed files with 50 additions and 0 deletions

View File

@ -0,0 +1,46 @@
from federatedscope.register import register_criterion
def call_my_criterion(type, device):
try:
import torch
import torch.nn as nn
except ImportError:
nn = None
criterion = None
class RMSELoss(nn.Module):
def __init__(self):
super(RMSELoss, self).__init__()
self.mse = nn.MSELoss()
def forward(self, y_pred, y_true):
return torch.sqrt(self.mse(y_pred, y_true))
class MAPELoss(nn.Module):
def __init__(self, epsilon=1e-10):
super(MAPELoss, self).__init__()
self.epsilon = epsilon
def forward(self, y_pred, y_true):
mask_value = 0.1
if mask_value is not None:
mask = torch.gt(y_true, mask_value)
pred = torch.masked_select(y_pred, mask)
true = torch.masked_select(y_true, mask)
return torch.mean(torch.abs(torch.div((true - pred), (true + 0.001)))) * 100
if type == 'RMSE':
if nn is not None:
criterion = RMSELoss().to(device)
elif type == 'MAPE':
if nn is not None:
criterion = MAPELoss().to(device)
else:
criterion = None
return criterion
# Register the custom RMSE and MAPE criterion
register_criterion('RMSE', call_my_criterion)
register_criterion('MAPE', call_my_criterion)

View File

@ -62,6 +62,10 @@ def get_dataloader(dataset, config, split='train'):
elif config.dataloader.type == 'mf':
from federatedscope.mf.dataloader import MFDataLoader
loader_cls = MFDataLoader
elif config.dataloader.type == 'trafficflow':
# 待定
from torch.utils.data import DataLoader
loader_cls = DataLoader
else:
raise ValueError(f'data.loader.type {config.data.loader.type} '
f'not found!')