FS-TFP/scripts/distributed_scripts/gen_data.py

110 lines
3.4 KiB
Python

import numpy as np
import pickle
import os
def _sample_split(weights, bias, scale, instance_num, feature_num):
    """Draw one ``{'x': ..., 'y': ...}`` split from the shared linear model.

    ``x`` is sampled from a zero-mean normal with standard deviation
    ``scale`` and has shape ``(instance_num, feature_num)``; ``y`` is the
    noise-free linear response ``sum(x * weights, axis=-1) + bias`` stored
    as a column of shape ``(instance_num, 1)``.
    """
    x = np.random.normal(loc=0.0,
                         scale=scale,
                         size=(instance_num, feature_num))
    y = np.expand_dims(np.sum(x * weights, axis=-1) + bias, -1)
    return {'x': x, 'y': y}


def _split_to_lists(split):
    """Convert every array of a split to nested lists; keep ``None`` as is."""
    if split is None:
        return None
    return {k: v.tolist() for k, v in split.items()}


def generate_data(client_num=3,
                  instance_num=1000,
                  feature_num=5,
                  save_data=True,
                  save_dir='toy_data'):
    """
    Generate data in Runner format

    A single random linear model (``weights`` and ``bias``) is drawn once
    and used to label every split. Each client gets its own training
    features, sampled with standard deviation ``0.5 * client_id``; one test
    split and one validation split (standard deviation 1.0) are sampled
    once and the *same* dict objects are shared by all clients and the
    server.

    Randomness comes from the global ``np.random`` state, so call
    ``np.random.seed(...)`` beforehand for reproducible output.

    Args:
        client_num: number of clients; client ids are ``1..client_num``.
        instance_num: number of rows in every split.
        feature_num: number of feature columns in ``x``.
        save_data: if True, pickle the generated data into ``save_dir``.
        save_dir: output directory (created if missing). It receives
            ``server_data``, ``client_{id}_data`` (arrays converted to
            nested lists) and ``all_data`` (the returned dict, with numpy
            arrays).

    Returns:
        A dict keyed by integer id, where ``0`` is the server:
        {
            0: {
                'train': None,
                'val': {'x': ..., 'y': ...},
                'test': {'x': ..., 'y': ...}
            },
            client_id: {
                'train': {'x': ..., 'y': ...},
                'test': {'x': ..., 'y': ...},
                'val': {'x': ..., 'y': ...}
            }
        }
    """
    # Draw order (weights, bias, client trains, test, val) is kept fixed so
    # that a given global seed always produces the same dataset.
    weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
    bias = np.random.normal(loc=0.0, scale=1.0)

    data = dict()
    for each_client in range(1, client_num + 1):
        # Per-client feature spread grows with the client id.
        data[each_client] = {
            'train': _sample_split(weights, bias, 0.5 * each_client,
                                   instance_num, feature_num)
        }

    # Evaluation splits are sampled once and shared by reference.
    test_data = _sample_split(weights, bias, 1.0, instance_num, feature_num)
    val_data = _sample_split(weights, bias, 1.0, instance_num, feature_num)
    for each_client in range(1, client_num + 1):
        data[each_client]['test'] = test_data
        data[each_client]['val'] = val_data

    # The server (id 0) holds no training data, only the shared eval splits.
    data[0] = {'train': None, 'val': val_data, 'test': test_data}

    if save_data:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(save_dir, exist_ok=True)

        for client_idx in range(0, client_num + 1):
            if client_idx == 0:
                filename = os.path.join(save_dir, 'server_data')
            else:
                filename = os.path.join(
                    save_dir, 'client_{:d}_data'.format(client_idx))
            # Build the payload before opening the file so that a failure
            # here does not leave an empty file behind.
            save_client_data = {
                split: _split_to_lists(data[client_idx][split])
                for split in ('train', 'val', 'test')
            }
            with open(filename, 'wb') as f:
                pickle.dump(save_client_data, f)

        with open(os.path.join(save_dir, 'all_data'), 'wb') as f:
            pickle.dump(data, f)

    return data
# Script entry point: build the toy dataset with the default arguments.
# Since save_data defaults to True, this call also writes the pickled files
# into the toy_data directory under the current working directory.
# NOTE(review): there is no `if __name__ == '__main__':` guard visible here,
# so this appears to run on import as well -- confirm that is intended.
data = generate_data()