110 lines
3.4 KiB
Python
110 lines
3.4 KiB
Python
import numpy as np
|
|
import pickle
|
|
import os
|
|
|
|
|
|
def _make_linear_split(weights, bias, scale, instance_num, feature_num):
    """Draw one split of noiseless linear-regression data.

    Args:
        weights: 1-D array of length ``feature_num`` with the true
            coefficients.
        bias: Scalar intercept added to every target.
        scale: Standard deviation of the zero-mean normal distribution the
            features are sampled from.
        instance_num: Number of rows to sample.
        feature_num: Number of feature columns.

    Returns:
        ``{'x': ndarray (instance_num, feature_num),
        'y': ndarray (instance_num, 1)}``
    """
    x = np.random.normal(loc=0.0,
                         scale=scale,
                         size=(instance_num, feature_num))
    y = np.sum(x * weights, axis=-1) + bias
    # Targets carry a trailing axis of length 1, i.e. shape (instance_num, 1).
    return {'x': x, 'y': np.expand_dims(y, -1)}


def _to_plain_lists(split):
    """Convert a split's arrays to nested lists; pass ``None`` through."""
    if split is None:
        return None
    return {key: value.tolist() for key, value in split.items()}


def generate_data(client_num=3,
                  instance_num=1000,
                  feature_num=5,
                  save_data=True):
    """
    Generate data in Runner format

    A single random linear model (weights and bias drawn from a standard
    normal) labels every split. Client ``i`` receives training features
    drawn with standard deviation ``0.5 * i``; one test split and one
    validation split (standard deviation 1.0) are generated once and the
    *same* dict objects are attached to every client and to the server, so
    an in-place change through one entry is visible through all of them.

    Args:
        client_num: Number of clients; client ids run from 1 to
            ``client_num``. Id 0 is reserved for the server.
        instance_num: Number of rows in every generated split.
        feature_num: Number of feature columns.
        save_data: If true, also pickle the data under ``toy_data/``
            (relative to the current working directory):
            ``server_data`` and ``client_{i}_data`` hold one participant
            each with arrays converted to nested lists, and ``all_data``
            holds the full returned dict with NumPy arrays.

    Returns:
        A dict keyed by integer id::

            {
                client_id: {            # 1 .. client_num
                    'train': {'x': ..., 'y': ...},
                    'test': {'x': ..., 'y': ...},
                    'val': {'x': ..., 'y': ...}
                },
                0: {                    # server
                    'train': None,
                    'val': {'x': ..., 'y': ...},
                    'test': {'x': ..., 'y': ...}
                }
            }

        where every ``x`` has shape ``(instance_num, feature_num)`` and
        every ``y`` has shape ``(instance_num, 1)``.
    """
    # The order of the np.random calls below (weights, bias, one train split
    # per client, test, val) is kept fixed so results under a given
    # np.random.seed stay reproducible.
    weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
    bias = np.random.normal(loc=0.0, scale=1.0)

    data = dict()
    for client_id in range(1, client_num + 1):
        # Each client gets a different feature spread (0.5 * client_id).
        data[client_id] = {
            'train': _make_linear_split(weights, bias, 0.5 * client_id,
                                        instance_num, feature_num)
        }

    # test data, then val data: generated once and shared by reference
    test_data = _make_linear_split(weights, bias, 1.0, instance_num,
                                   feature_num)
    val_data = _make_linear_split(weights, bias, 1.0, instance_num,
                                  feature_num)
    for client_id in range(1, client_num + 1):
        data[client_id]['test'] = test_data
        data[client_id]['val'] = val_data

    # server_data: the server only evaluates, so it has no training split
    data[0] = {'train': None, 'val': val_data, 'test': test_data}

    if save_data:
        # exist_ok avoids the check-then-create race of
        # os.path.exists + os.makedirs.
        os.makedirs('toy_data', exist_ok=True)

        for client_idx in range(0, client_num + 1):
            if client_idx == 0:
                filename = 'toy_data/server_data'
            else:
                filename = 'toy_data/client_{:d}_data'.format(client_idx)

            # Build the payload before opening the file so that a failure
            # while converting does not leave a truncated pickle behind.
            payload = {
                split: _to_plain_lists(data[client_idx][split])
                for split in ('train', 'val', 'test')
            }
            with open(filename, 'wb') as f:
                pickle.dump(payload, f)

        with open('toy_data/all_data', 'wb') as f:
            pickle.dump(data, f)

    return data
|
|
|
|
|
|
# Build the default toy dataset at module level (runs on import as well as on
# direct execution). With the default save_data=True this also writes the
# pickle files under ./toy_data/.
# NOTE(review): consider an `if __name__ == "__main__":` guard if nothing
# imports `data` from this module -- confirm against callers first.
data = generate_data()
|