FS-TFP/federatedscope/tabular/dataloader/toy.py

115 lines
3.9 KiB
Python

import copy
import os
import pickle

import numpy as np

from federatedscope.core.data.wrap_dataset import WrapDataset
def load_toy_data(config=None):
    """
    Build a synthetic linear-regression dataset for every participant and
    wrap each split in ``WrapDataset``.

    Args:
        config: configuration object. ``config.federate.client_num`` and
            ``config.data.save_data`` are read, so despite the ``None``
            default a real configuration has to be supplied.

    Returns:
        A tuple ``(data, config)``. ``data`` maps a participant id to
        ``{split_name: WrapDataset}``: id ``0`` is the server entry (only
        ``'val'`` and ``'test'``), ids ``1..client_num`` are the clients
        (``'train'``, ``'test'`` and ``'val'``). ``config`` is handed back
        unmodified.
    """
    def _generate_data(client_num=5,
                       instance_num=1000,
                       feature_num=5,
                       save_data=False):
        """
        Generate data in FedRunner format

        Args:
            client_num: number of clients; client ids run from 1 to
                client_num, id 0 is used for the server
            instance_num: number of samples drawn for each split
            feature_num: number of features of each sample
            save_data: if True, pickle every participant's splits (arrays
                converted to nested lists) into the ``data/`` directory

        Returns:
            {
                '{client_id}': {
                    'train': {
                        'x': ...,
                        'y': ...
                    },
                    'test': {
                        'x': ...,
                        'y': ...
                    },
                    'val': {
                        'x': ...,
                        'y': ...
                    }
                }
            }
            The server entry (key 0) only carries 'val' and 'test'.
        """
        # One shared ground-truth linear model: y = x . weights + bias
        weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
        bias = np.random.normal(loc=0.0, scale=1.0)

        def _make_split(scale):
            # Draw features with the given std and label them with the
            # shared linear model; y gets a trailing axis of length 1.
            x = np.random.normal(loc=0.0,
                                 scale=scale,
                                 size=(instance_num, feature_num))
            y = np.sum(x * weights, axis=-1) + bias
            return {'x': x, 'y': np.expand_dims(y, -1)}

        data = dict()
        for each_client in range(1, client_num + 1):
            # The feature std grows with the client id, so every client
            # trains on a differently scaled distribution.
            data[each_client] = {'train': _make_split(0.5 * each_client)}

        # test data: generated once, every client receives its own copy
        test_data = _make_split(1.0)
        for each_client in range(1, client_num + 1):
            data[each_client]['test'] = copy.deepcopy(test_data)

        # val data: generated once, every client receives its own copy
        val_data = _make_split(1.0)
        for each_client in range(1, client_num + 1):
            data[each_client]['val'] = copy.deepcopy(val_data)

        # server_data: no 'train' split, val/test are the un-copied originals
        data[0] = dict()
        data[0]['val'] = val_data
        data[0]['test'] = test_data

        if save_data:
            # The target directory may not exist yet in a fresh checkout.
            os.makedirs('data', exist_ok=True)
            for client_idx in range(0, client_num + 1):
                if client_idx == 0:
                    filename = 'data/server_data'
                else:
                    filename = 'data/client_{:d}_data'.format(client_idx)
                # Only serialize the splits this participant actually has:
                # the server has no 'train' split, and indexing it
                # unconditionally raised KeyError. A fresh dict per
                # participant keeps one file's content out of the next.
                save_client_data = {
                    split: {
                        k: v.tolist()
                        for k, v in data[client_idx][split].items()
                    }
                    for split in ('train', 'val', 'test')
                    if split in data[client_idx]
                }
                # Build the payload first so a conversion error cannot
                # leave a truncated file behind.
                with open(filename, 'wb') as f:
                    pickle.dump(save_client_data, f)

        return data

    data = _generate_data(client_num=config.federate.client_num,
                          save_data=config.data.save_data)
    # Replace the raw {'x', 'y'} dicts by WrapDataset objects in place;
    # only values are reassigned, so iterating over the keys stays valid.
    for client_id in data.keys():
        data[client_id] = {
            k: WrapDataset(v)
            for k, v in data[client_id].items()
        } if data[client_id] is not None else None

    return data, config