20 lines
670 B
Python
20 lines
670 B
Python
import numpy as np
|
|
|
|
|
|
def load_quadratic_dataset(config):
    """Build a synthetic quadratic dataset with one entry per client.

    Client ``i`` (1-based) gets the matrix
    ``0.02 * base**(i - 1) * identity(dim)``, where ``base`` is chosen so
    that the scale grows geometrically by a total factor of
    ``max_curv / min_curv`` from the first client to the last one.

    Args:
        config: Configuration object. The fields read here are
            ``config.data.quadratic.dim``,
            ``config.data.quadratic.min_curv``,
            ``config.data.quadratic.max_curv`` and
            ``config.federate.client_num``.

    Returns:
        A tuple ``(data_dict, config)``. ``data_dict`` maps each client id
        in ``1..client_num`` to a dict with the keys ``'train'``, ``'val'``
        and ``'test'``; every split is a one-element list holding the pair
        ``(matrix as float32, 0.0)``. ``config`` is returned unchanged.
    """
    quad_cfg = config.data.quadratic
    client_num = config.federate.client_num
    d = quad_cfg.dim

    if client_num > 1:
        # Common ratio of the geometric progression across clients.
        base = np.exp(
            np.log(quad_cfg.max_curv / quad_cfg.min_curv) / (client_num - 1))
    else:
        # With a single client the exponent below is always 0, so the ratio
        # is irrelevant; avoid dividing by ``client_num - 1 == 0``, which
        # would trigger a NumPy divide-by-zero / invalid-value warning.
        base = 1.0

    data_dict = {}
    for i in range(1, 1 + client_num):
        # TODO: enable sphere
        # NOTE(review): 0.02 is hard-coded rather than read from
        # ``min_curv``; presumably it mirrors the default value - confirm.
        a = 0.02 * base**(i - 1) * np.identity(d)
        # TODO: enable non-zero minimizer, i.e., provide a shift
        # Each split receives its own float32 copy of the matrix; the
        # second tuple element is presumably the (currently zero) shift.
        data_dict[i] = {
            split: [(a.astype(np.float32), .0)]
            for split in ('train', 'val', 'test')
        }
    return data_dict, config
|