FS-TFP/federatedscope/vertical_fl/dataloader/dataloader.py

86 lines
2.6 KiB
Python

import numpy as np
from federatedscope.vertical_fl.dataset import Adult, Abalone, Credit, Blog
from federatedscope.core.data.wrap_dataset import WrapDataset
VERTICAL_DATASET = {
'adult': Adult,
'credit': Credit,
'abalone': Abalone,
'blog': Blog,
}
def load_vertical_data(config=None, generate=False):
"""
To load data for vertical FL
Arguments:
config: configuration
generate (bool): whether to generate the synthetic data
:returns: The data, the modified config
:rtype: dict
"""
dataset_name = config.data.type.lower()
# TODO: merge the following later
if config.data.args:
args = config.data.args[0]
else:
args = {'normalization': False, 'standardization': False}
if not generate:
dataset_class = VERTICAL_DATASET[dataset_name]
dataset = dataset_class(root=config.data.root,
num_of_clients=config.federate.client_num,
feature_partition=config.vertical.dims,
tr_frac=config.data.splits[0],
algo=config.vertical.algo,
debug_size=config.vertical.data_size_for_debug,
download=True,
seed=1234,
args=args)
data = dataset.data
return data, config
else:
# generate toy data for running a vertical FL example
INSTANCE_NUM = 1000
TRAIN_SPLIT = 0.9
total_dims = config.vertical.dims[-1]
theta = np.random.uniform(low=-1.0, high=1.0, size=(total_dims, 1))
x = np.random.choice([-1.0, 1.0, -2.0, 2.0, -3.0, 3.0],
size=(INSTANCE_NUM, total_dims))
y = np.asarray([
1.0 if x >= 0 else -1.0
for x in np.reshape(np.matmul(x, theta), -1)
])
train_num = int(TRAIN_SPLIT * INSTANCE_NUM)
test_data = {'theta': theta, 'x': x[train_num:], 'y': y[train_num:]}
data = dict()
# For Server
data[0] = dict()
data[0]['train'] = None
data[0]['val'] = None
data[0]['test'] = test_data
# For Client #1
data[1] = dict()
data[1]['train'] = {'x': x[:train_num, :config.vertical.dims[0]]}
data[1]['val'] = None
data[1]['test'] = test_data
# For Client #2
data[2] = dict()
data[2]['train'] = {
'x': x[:train_num, config.vertical.dims[0]:],
'y': y[:train_num]
}
data[2]['val'] = None
data[2]['test'] = test_data
return data, config