import copy
import logging
import os
import pickle

import numpy as np
import GPy

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.autotune.fts.utils import *
from federatedscope.autotune.utils import parse_search_space

logger = logging.getLogger(__name__)


class FTSServer(Server):
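    """Server-side worker for Federated Thompson Sampling (FTS) HPO.

    The server keeps one GP surrogate per client. In the first round it
    collects each client's local BO initializations and shared information;
    in every later round it proposes a new hyperparameter configuration for
    each target client by maximizing a Thompson-sampling acquisition, drawn
    either from the client's own GP or from other clients' shared samples.
    """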
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(FTSServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        assert self.sample_client_num == self._cfg.federate.client_num

        self.util_ts = UtilityFunction(kind="ts")
        self.M = self._cfg.hpo.fts.M

        # server file paths
        self.all_local_init_path = os.path.join(self._cfg.hpo.working_folder,
                                                "all_localBO_inits.pkl")
        self.all_local_info_path = os.path.join(self._cfg.hpo.working_folder,
                                                "all_localBO_infos.pkl")
        self.rand_feat_path = os.path.join(
            self._cfg.hpo.working_folder,
            "rand_feat_M_" + str(self.M) + ".pkl")

        # prepare search space and bounds; the acquisition is optimized in
        # the normalized [0, 1]^d cube, while pbounds keeps the actual
        # (possibly log10-scaled) range of each hyperparameter
        self._ss = parse_search_space(self._cfg.hpo.fts.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}
        for k, v in self._ss.items():
            if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                raise ValueError("Unsupported hyper type {}".format(type(v)))
            if v.log:
                l, u = np.log10(v.lower), np.log10(v.upper)
            else:
                l, u = v.lower, v.upper
            self.pbounds[k] = (l, u)

        # probability schedule p_t for Thompson sampling: at iteration t,
        # with probability pt[t-1] the server samples from the target
        # client's own GP, otherwise from another client's information;
        # p_t = 1 - 1/t^2 approaches 1, so later iterations rely
        # increasingly on the client's own observations
        pt = 1 - 1 / (np.arange(self._cfg.hpo.fts.fed_bo_max_iter + 5) +
                      1)**2.0
        pt[0] = pt[1]
        self.pt = pt
        self.num_other_clients = self._cfg.federate.client_num - 1
        # uniform weights for choosing which other client to sample from
        self.ws = np.ones(self.num_other_clients) / self.num_other_clients

        # records for all GP models; index 0 is unused so that client IDs
        # (1..client_num) can be used as indices directly
        N = self._cfg.federate.client_num + 1
        self.x_max = [None for _ in range(N)]
        self.y_max = [None for _ in range(N)]
        self.X = [None for _ in range(N)]
        self.Y = [None for _ in range(N)]
        self.incumbent = [None for _ in range(N)]
        self.gp = [None for _ in range(N)]
        self.gp_params = [None for _ in range(N)]
        self.initialized = [False for _ in range(N)]
        self.res = [{
            'max_value': None,
            'max_param': None,
            'all_values': [],
            'all_params': [],
        } for _ in range(N)]
        self.res_paths = [
            os.path.join(self._cfg.hpo.working_folder,
                         "result_params_%d.pkl" % cid) for cid in range(N)
        ]

        # load or generate agent_info, agent_init, and rand_feat:
        # if the files already exist, load them from disk;
        # otherwise, request them from the clients in the first round
        if os.path.exists(self.all_local_info_path) and \
                self._cfg.hpo.fts.allow_load_existing_info:
            logger.info('Using existing rand_feat, agent_infos, '
                        'and agent_inits')
            self.require_agent_infos = False
            self.state = 1
            self.random_feats = pickle.load(open(self.rand_feat_path, 'rb'))
            self.all_agent_info = pickle.load(
                open(self.all_local_info_path, 'rb'))
            self.all_agent_init = pickle.load(
                open(self.all_local_init_path, 'rb'))
        else:
            self.require_agent_infos = True
            self.random_feats = self._generate_shared_rand_feats()
            self.all_agent_info = {}
            self.all_agent_init = {}

        # specify the target clients to be optimized by FTS; default to all
        if not config.hpo.fts.target_clients:
            self.target_clients = list(
                range(1, self._cfg.federate.client_num + 1))
        else:
            self.target_clients = list(config.hpo.fts.target_clients)

    def _generate_shared_rand_feats(self):
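        """Generate random Fourier features shared across all clients.

        Following the random-features approximation of an RBF kernel
        (Rahimi & Recht, 2007), frequencies s ~ N(0, ls^{-2} I) and phases
        b ~ U[0, 2*pi) define features of the form
        sqrt(2 * v_kernel / M) * cos(x @ s.T + b). Sharing the same (s, b)
        across clients puts everyone's GP information into a common feature
        space, so weight samples can be exchanged without sharing raw data.
        """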
        ls = self._cfg.hpo.fts.ls
        v_kernel = self._cfg.hpo.fts.v_kernel
        obs_noise = self._cfg.hpo.fts.obs_noise
        M = self.M

        # frequencies s ~ N(0, ls^{-2} I) and phases b ~ U[0, 2*pi)
        s = np.random.multivariate_normal(np.zeros(self.dim),
                                          1 / ls**2 * np.identity(self.dim),
                                          M)
        b = np.random.uniform(0, 2 * np.pi, M)
        random_features = {
            "M": M,
            "length_scale": ls,
            "s": s,
            "b": b,
            "obs_noise": obs_noise,
            "v_kernel": v_kernel
        }
        pickle.dump(random_features, open(self.rand_feat_path, "wb"))

        return random_features

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        Broadcast the message to all clients (when collecting agent
        information in the first round) or to the target clients (when
        sending new hyperparameter proposals in later rounds).
        """
        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if self.require_agent_infos:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')
        else:
            receiver = list(self.target_clients)

        for rcv_idx in receiver:
            if self.require_agent_infos:
                # first round: ask the client for its local BO information
                # and ship the shared random features
                content = {
                    'require_agent_infos': True,
                    'random_feats': self.random_feats,
                }

            else:
                # lazily build this client's GP from its initial design
                # points the first time we propose for it
                if not self.initialized[rcv_idx]:
                    init = self.all_agent_init[rcv_idx]
                    self.X[rcv_idx] = init['X']
                    self.Y[rcv_idx] = init['Y']
                    self.incumbent[rcv_idx] = np.max(self.Y[rcv_idx])
                    logger.info("Using pre-existing initializations "
                                "for client {} with {} points".format(
                                    rcv_idx, len(self.Y[rcv_idx])))

                    y_max = np.max(self.Y[rcv_idx])
                    self.y_max[rcv_idx] = y_max

                    # fit an RBF GP on the unique design points
                    ur = unique_rows(self.X[rcv_idx])
                    self.gp[rcv_idx] = GPy.models.GPRegression(
                        self.X[rcv_idx][ur],
                        self.Y[rcv_idx][ur].reshape(-1, 1),
                        GPy.kern.RBF(input_dim=self.X[rcv_idx].shape[1],
                                     lengthscale=self._cfg.hpo.fts.ls,
                                     variance=self._cfg.hpo.fts.var,
                                     ARD=False))
                    self.gp[rcv_idx]["Gaussian_noise.variance"][0] = \
                        self._cfg.hpo.fts.g_var
                    self._opt_gp(rcv_idx)
                    self.initialized[rcv_idx] = True

                # sample a candidate either from this client's own GP or
                # from the other clients' shared information
                info_ts = copy.deepcopy(self.all_agent_info)
                del info_ts[rcv_idx]
                info_ts = list(info_ts.values())

                def _sample_with_retry(func):
                    # the acquisition optimization can occasionally fail
                    # numerically, so retry until it succeeds
                    while True:
                        try:
                            return func(rcv_idx, self.y_max[rcv_idx],
                                        self.state, info_ts)
                        except Exception:
                            continue

                # with probability pt[t-1] exploit the client's own GP,
                # otherwise borrow a sample from another client
                if np.random.random() < self.pt[self.state - 1]:
                    x_max, all_ucb = _sample_with_retry(
                        self._sample_from_this)
                else:
                    x_max, all_ucb = _sample_with_retry(
                        self._sample_from_others)
                self.x_max[rcv_idx] = x_max

                content = {
                    'require_agent_infos': False,
                    'x_max': x_max,
                }

            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def _sample_from_this(self, client, y_max, iteration, info_ts):
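        """Draw a Thompson sample from the client's own GP.

        A fresh set of M_target random Fourier features is built from the
        GP's fitted kernel hyperparameters, the Bayesian linear model in
        that feature space is conditioned on the client's observations, a
        weight vector w is sampled from its posterior, and the acquisition
        phi(x)^T w is maximized over the normalized search space.
        """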
        M_target = self._cfg.hpo.fts.M_target

        # use the kernel hyperparameters fitted on this client's data
        ls_target = self.gp[client]["rbf.lengthscale"][0]
        v_kernel = self.gp[client]["rbf.variance"][0]
        obs_noise = self.gp[client]["Gaussian_noise.variance"][0]

        s = np.random.multivariate_normal(
            np.zeros(self.dim), 1 / (ls_target**2) * np.identity(self.dim),
            M_target)
        b = np.random.uniform(0, 2 * np.pi, M_target)
        random_features_target = {
            "M": M_target,
            "length_scale": ls_target,
            "s": s,
            "b": b,
            "obs_noise": obs_noise,
            "v_kernel": v_kernel
        }

        # map each observed point into the random feature space
        Phi = np.zeros((self.X[client].shape[0], M_target))
        for i, x in enumerate(self.X[client]):
            x = np.squeeze(x).reshape(1, -1)
            features = np.sqrt(
                2 / M_target) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
            features = features / np.sqrt(np.inner(features, features))
            features = np.sqrt(v_kernel) * features
            Phi[i, :] = features

        # posterior of the linear weights: Sigma_t = Phi^T Phi + sigma^2 I,
        # nu_t = Sigma_t^{-1} Phi^T y, and w ~ N(nu_t, sigma^2 Sigma_t^{-1})
        Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M_target)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T),
                      self.Y[client].reshape(-1, 1))
        w_sample = np.random.multivariate_normal(np.squeeze(nu_t),
                                                 obs_noise * Sigma_t_inv, 1)

        x_max, all_ucb = acq_max(ac=self.util_ts.utility,
                                 gp=self.gp[client],
                                 M=M_target,
                                 N=self.num_other_clients,
                                 gp_samples=None,
                                 random_features=random_features_target,
                                 info_ts=info_ts,
                                 pt=self.pt,
                                 ws=self.ws,
                                 use_target_label=True,
                                 w_sample=w_sample,
                                 y_max=y_max,
                                 bounds=self.bounds,
                                 iteration=iteration)
        return x_max, all_ucb

    def _sample_from_others(self, client, y_max, iteration, info_ts):
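        """Draw a Thompson sample borrowed from another client.

        One of the other clients is chosen with probabilities ws, and its
        shared sample (as collected in ``all_agent_info``, expressed in the
        shared random feature space) is used as ``w_sample`` when
        maximizing the acquisition.
        """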
        # pick one of the other clients according to the weights ws
        agent_ind = np.arange(self.num_other_clients)
        random_agent_n = np.random.choice(agent_ind, 1, p=self.ws)[0]
        w_sample = info_ts[random_agent_n]
        x_max, all_ucb = acq_max(ac=self.util_ts.utility,
                                 gp=self.gp[client],
                                 M=self.M,
                                 N=self.num_other_clients,
                                 gp_samples=None,
                                 random_features=self.random_feats,
                                 info_ts=info_ts,
                                 pt=self.pt,
                                 ws=self.ws,
                                 use_target_label=False,
                                 w_sample=w_sample,
                                 y_max=y_max,
                                 bounds=self.bounds,
                                 iteration=iteration)
        return x_max, all_ucb

    def _opt_gp(self, client):
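        """Re-fit the GP kernel hyperparameters of the given client."""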
        self.gp[client].optimize_restarts(num_restarts=10,
                                          messages=False,
                                          verbose=False)
        self.gp_params[client] = self.gp[client].parameters

    def callback_funcs_model_para(self, message: Message):
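        """Buffer a client's message and try to move on to the next round.

        In the first round the content carries the client's agent_info and
        agent_init; in later rounds it carries the measured performance
        ('curr_y') of the proposed configuration.
        """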
        round, sender, content = message.state, message.sender, message.content
        self.sampler.change_state(sender, 'idle')
        # for a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()

        self.msg_buffer['train'][round][sender] = content

        return self.check_and_move_on()

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message buffer. When enough messages have been
        received, certain events (such as performing aggregation,
        evaluation, or moving to the next training round) are triggered.

        Arguments:
            check_eval_result (bool): If True, check the message buffer for
                evaluation; otherwise, check the message buffer for training.
        """
        if min_received_num is None or check_eval_result:
            min_received_num = len(self.target_clients)
        if self.require_agent_infos:
            min_received_num = self._cfg.federate.client_num

        move_on_flag = True

        # when enough messages have been received
        if self.check_buffer(self.state, min_received_num, check_eval_result):

            if not check_eval_result:
                # the first round collects the clients' information:
                # receive agent_info and agent_init, then cache them to disk
                if self.require_agent_infos:
                    for _client, _content in \
                            self.msg_buffer['train'][self.state].items():
                        assert _content['is_required_agent_info']
                        self.all_agent_info[_client] = _content['agent_info']
                        self.all_agent_init[_client] = _content['agent_init']
                    pickle.dump(self.all_agent_info,
                                open(self.all_local_info_path, "wb"))
                    pickle.dump(self.all_agent_init,
                                open(self.all_local_init_path, "wb"))
                    self.require_agent_infos = False

                # later rounds receive the measured performance and
                # update the per-client GP models
                else:
                    for _client, _content in \
                            self.msg_buffer['train'][self.state].items():
                        curr_y = _content['curr_y']
                        self.Y[_client] = np.append(self.Y[_client], curr_y)
                        self.X[_client] = np.vstack(
                            (self.X[_client], self.x_max[_client].reshape(
                                (1, -1))))
                        if self.Y[_client][-1] > self.y_max[_client]:
                            self.y_max[_client] = self.Y[_client][-1]
                            self.incumbent[_client] = self.Y[_client][-1]
                        ur = unique_rows(self.X[_client])
                        self.gp[_client].set_XY(X=self.X[_client][ur],
                                                Y=self.Y[_client][ur].reshape(
                                                    -1, 1))

                        # periodically re-fit the GP kernel hyperparameters
                        _schedule = self._cfg.hpo.fts.gp_opt_schedule
                        if self.state >= _schedule \
                                and self.state % _schedule == 0:
                            self._opt_gp(_client)

                        # record the incumbent and the full history, then
                        # checkpoint the results for this client
                        x_max_param = self.X[_client][self.Y[_client].argmax()]
                        hyper_param = x2conf(x_max_param, self.pbounds,
                                             self._ss)
                        self.res[_client]['max_param'] = hyper_param
                        self.res[_client]['max_value'] = self.Y[_client].max()
                        self.res[_client]['all_values'].append(
                            self.Y[_client][-1].tolist())
                        self.res[_client]['all_params'].append(
                            self.X[_client][-1].tolist())
                        pickle.dump(self.res[_client],
                                    open(self.res_paths[_client], 'wb'))

                self.state += 1
                if self.state <= self._cfg.hpo.fts.fed_bo_max_iter:
                    # move to the next round of training
                    logger.info(f'----------- GP optimizing iteration (Round '
                                f'#{self.state}) -------------')
                    # clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.broadcast_model_para()
                else:
                    # final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:
                _results = {}
                for _client, _content in \
                        self.msg_buffer['eval'][self.state].items():
                    _results[_client] = _content

                self.history_results = merge_dict(self.history_results,
                                                  _results)
                if self.state > self._cfg.hpo.fts.fed_bo_max_iter:
                    self.check_and_save()
        else:
            move_on_flag = False

        return move_on_flag

    def check_and_save(self):
        """
        To save the results and the model after each evaluation.
        """
        # early stopping
        should_stop = False

        formatted_best_res = self._monitor.format_eval_res(
            results=self.history_results,
            rnd="Final",
            role='Server #',
            forms=["raw"],
            return_raw=True)
        logger.info('*' * 50)
        logger.info(formatted_best_res)
        self._monitor.save_formatted_results(formatted_best_res)

        if should_stop:
            self.state = self.total_round_num + 1

        logger.info('*' * 50)
        if self.state == self.total_round_num:
            # break out of the loop for distributed mode
            self.state += 1