FS-TFP/federatedscope/tabular/dataloader/quadratic.py

20 lines
670 B
Python

import numpy as np
def load_quadratic_dataset(config):
data_dict = dict()
d = config.data.quadratic.dim
base = np.exp(
np.log(config.data.quadratic.max_curv / config.data.quadratic.min_curv)
/ (config.federate.client_num - 1))
for i in range(1, 1 + config.federate.client_num):
# TODO: enable sphere
a = 0.02 * base**(i - 1) * np.identity(d)
# TODO: enable non-zero minimizer, i.e., provide a shift
data_dict[i] = {
'train': [(a.astype(np.float32), .0)],
'val': [(a.astype(np.float32), .0)],
'test': [(a.astype(np.float32), .0)]
}
return data_dict, config