import logging

import numpy as np

from federatedscope.core.auxiliaries.splitter_builder import get_splitter
from federatedscope.core.data import ClientData, StandaloneDataDict

logger = logging.getLogger(__name__)
class BaseDataTranslator:
    """
    Translator is a tool to convert a centralized dataset to \
    ``StandaloneDataDict``, which is the input data of runner.

    Notes:
        The ``Translator`` consists of several stages:

        Dataset -> ML split (``split_train_val_test()``) -> \
        FL split (``split_to_client()``) -> ``StandaloneDataDict``

    """
    def __init__(self, global_cfg, client_cfgs=None):
        """
        Convert data to ``StandaloneDataDict``.

        Args:
            global_cfg: global CfgNode
            client_cfgs: client cfg `Dict`
        """
        self.global_cfg = global_cfg
        self.client_cfgs = client_cfgs
        self.splitter = get_splitter(global_cfg)

    def __call__(self, dataset):
        """
        Args:
            dataset: `torch.utils.data.Dataset`, `List` of (feature, label)
                or a pre-split (train, val, test) tuple.

        Returns:
            datadict: instance of `StandaloneDataDict`, which is a subclass
                of `dict`.
        """
        datadict = self.split(dataset)
        datadict = StandaloneDataDict(datadict, self.global_cfg)
        return datadict

    def split(self, dataset):
        """
        Perform ML split and FL split.

        Returns:
            dict of ``ClientData`` with client_idx as key to build \
            ``StandaloneDataDict``
        """
        train, val, test = self.split_train_val_test(dataset)
        datadict = self.split_to_client(train, val, test)
        return datadict

    def split_train_val_test(self, dataset, cfg=None):
        """
        Split dataset to train, val, test if not provided.

        Args:
            dataset: dataset to split, or a pre-split (train, val, test)
                tuple, which is returned unchanged (as a list).
            cfg: optional CfgNode; its ``data.splits`` ratios override
                ``self.global_cfg.data.splits`` when given.

        Returns:
            List: List of split dataset, like ``[train, val, test]``

        Raises:
            ValueError: if ``dataset`` is a tuple that does not contain
                exactly the three (train, val, test) splits.
        """
        # Local import so that code paths which never split (e.g. tuple
        # passthrough callers without torch datasets) stay importable.
        from torch.utils.data import Dataset, Subset

        splits = cfg.data.splits if cfg is not None else \
            self.global_cfg.data.splits

        if isinstance(dataset, tuple):
            # No need to split train/val/test for tuple dataset.
            # Raise instead of `assert`: asserts are stripped under
            # `python -O`, which would let malformed tuples through.
            if len(dataset) != 3:
                raise ValueError('If dataset is tuple, it must contains '
                                 'train, valid and test split.')
            return [dataset[0], dataset[1], dataset[2]]

        num_samples = len(dataset)
        index = np.random.permutation(np.arange(num_samples))
        train_size = int(splits[0] * num_samples)
        val_size = int(splits[1] * num_samples)

        if isinstance(dataset, Dataset):
            # Torch dataset: use index-based views without materializing.
            train_dataset = Subset(dataset, index[:train_size])
            val_dataset = Subset(dataset,
                                 index[train_size:train_size + val_size])
            test_dataset = Subset(dataset, index[train_size + val_size:])
        else:
            # Plain sequence of (feature, label): materialize each split.
            train_dataset = [dataset[x] for x in index[:train_size]]
            val_dataset = [
                dataset[x] for x in index[train_size:train_size + val_size]
            ]
            test_dataset = [dataset[x] for x in index[train_size + val_size:]]
        return train_dataset, val_dataset, test_dataset

    def split_to_client(self, train, val, test):
        """
        Split dataset to clients and build ``ClientData``.

        Returns:
            dict: dict of ``ClientData`` with ``client_idx`` as key.
        """
        # Initialization. NOTE: build three INDEPENDENT placeholder lists;
        # the previous `[[None] * n] * 3` aliased one list three times.
        client_num = self.global_cfg.federate.client_num
        split_train = [None] * client_num
        split_val = [None] * client_num
        split_test = [None] * client_num
        train_label_distribution = None

        # Split train/val/test to client
        if len(train) > 0:
            split_train = self.splitter(train)
            if self.global_cfg.data.consistent_label_distribution:
                try:
                    train_label_distribution = [[j[1] for j in x]
                                                for x in split_train]
                except Exception:
                    # Best-effort: some items don't expose a label via
                    # `item[1]`; fall back to prior-free splitting. Narrowed
                    # from a bare `except:` so KeyboardInterrupt/SystemExit
                    # still propagate.
                    logger.warning(
                        'Cannot access train label distribution for '
                        'splitter.')
        if len(val) > 0:
            split_val = self.splitter(val, prior=train_label_distribution)
        if len(test) > 0:
            split_test = self.splitter(test, prior=train_label_distribution)

        # Build data dict with `ClientData`, key `0` for server.
        data_dict = {
            0: ClientData(self.global_cfg, train=train, val=val, test=test)
        }
        for client_id in range(1, client_num + 1):
            if self.client_cfgs is not None:
                # Per-client override: clone global cfg, then merge the
                # client-specific cfg on top of it.
                client_cfg = self.global_cfg.clone()
                client_cfg.merge_from_other_cfg(
                    self.client_cfgs.get(f'client_{client_id}'))
            else:
                client_cfg = self.global_cfg
            data_dict[client_id] = ClientData(client_cfg,
                                              train=split_train[client_id - 1],
                                              val=split_val[client_id - 1],
                                              test=split_test[client_id - 1])
        return data_dict