239 lines
9.1 KiB
Python
239 lines
9.1 KiB
Python
import os
|
|
import logging
|
|
import json
|
|
import copy
|
|
import pickle
|
|
import numpy as np
|
|
|
|
from federatedscope.core.message import Message
|
|
from federatedscope.core.workers import Client
|
|
|
|
from federatedscope.autotune.fts.utils import *
|
|
from federatedscope.autotune.utils import parse_search_space
|
|
from federatedscope.core.auxiliaries.trainer_builder import get_trainer
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FTSClient(Client):
    """Client for Federated Thompson Sampling (FTS) hyper-parameter tuning.

    Runs local Bayesian optimisation over the configured search space,
    distills the observations into random-Fourier-feature (RFF) info for
    the server, and evaluates server-proposed hyper-parameter points.
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        # Zero-argument super() instead of the legacy
        # super(FTSClient, self) spelling (Python 3 idiom).
        super().__init__(ID, server_id, state, config, data, model, device,
                         strategy, is_unseen_client, *args, **kwargs)
        self.data = data
        self.model = model
        self.device = device
        # If True the BO objective is the *improvement* of validation
        # loss (before - after local training); otherwise an absolute
        # baseline-shifted score (see _obj_func).
        self._diff = config.hpo.fts.diff
        # Snapshot of the untrained model so every BO evaluation can
        # restart from identical initial weights.
        self._init_model = copy.deepcopy(model)

        # Local file paths used to exchange BO history / init points /
        # RFF info with the server via the shared working folder.
        self.local_bo_path = os.path.join(self._cfg.hpo.working_folder,
                                          "local_bo_" + str(self.ID) + ".pkl")
        self.local_init_path = os.path.join(
            self._cfg.hpo.working_folder,
            "local_init_" + str(self.ID) + ".pkl")
        self.local_info_path = os.path.join(
            self._cfg.hpo.working_folder, "local_info_" + str(self.ID) +
            "_M_" + str(self._cfg.hpo.fts.M) + ".pkl")

        # Prepare the search space: BO itself works in the unit cube
        # (self.bounds); self.pbounds keeps the raw (possibly log10-
        # scaled) per-parameter bounds for decoding points.
        self._ss = parse_search_space(self._cfg.hpo.fts.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}
        for k, v in self._ss.items():
            # Guard clause: only numeric hyper-parameters carrying
            # lower/upper bounds are supported.
            if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                raise ValueError("Unsupported hyper type {}".format(type(v)))
            if v.log:
                l, u = np.log10(v.lower), np.log10(v.upper)
            else:
                l, u = v.lower, v.upper
            self.pbounds[k] = (l, u)
def _apply_hyperparams(self, hyperparams):
|
|
"""Apply the given hyperparameters
|
|
Arguments:
|
|
hyperparams (dict): keys are hyperparameter names \
|
|
and values are specific choices.
|
|
"""
|
|
|
|
cmd_args = []
|
|
for k, v in hyperparams.items():
|
|
cmd_args.append(k)
|
|
cmd_args.append(v)
|
|
|
|
self._cfg.defrost()
|
|
self._cfg.merge_from_list(cmd_args)
|
|
self._cfg.freeze(inform=False)
|
|
|
|
self.trainer.ctx.setup_vars()
|
|
|
|
def _get_new_trainer(self):
|
|
self.model = copy.deepcopy(self._init_model)
|
|
self.trainer = get_trainer(model=self.model,
|
|
data=self.data,
|
|
device=self.device,
|
|
config=self._cfg,
|
|
is_attacker=self.is_attacker,
|
|
monitor=self._monitor)
|
|
|
|
def _obj_func(self, x, return_eval=False):
|
|
"""
|
|
Run local evaluation, return some metric to maximize (e.g. val_acc)
|
|
"""
|
|
self._get_new_trainer()
|
|
|
|
baseline = 5.0
|
|
hyperparams = x2conf(x, self.pbounds, self._ss)
|
|
self._apply_hyperparams(hyperparams)
|
|
|
|
results_before = self.trainer.evaluate('val')
|
|
for _ in range(self._cfg.hpo.fts.local_bo_epochs):
|
|
sample_size, model_para_all, results = self.trainer.train()
|
|
results_after = self.trainer.evaluate('val')
|
|
|
|
if self._diff:
|
|
res = results_before['val_avg_loss'] \
|
|
- results_after['val_avg_loss']
|
|
else:
|
|
res = baseline - results_after['val_avg_loss']
|
|
if return_eval:
|
|
return res, results_after
|
|
else:
|
|
return res
|
|
|
|
def _generate_agent_info(self, rand_feats):
|
|
logger.info(
|
|
('-' * 20, ' generate info on clinet %d ' % self.ID, '_' * 20))
|
|
v_kernel = self._cfg.hpo.fts.v_kernel
|
|
obs_noise = self._cfg.hpo.fts.obs_noise
|
|
M = self._cfg.hpo.fts.M
|
|
M_target = self._cfg.hpo.fts.M_target
|
|
|
|
# run standard BO locally
|
|
max_iter = self._cfg.hpo.fts.local_bo_max_iter
|
|
gp_opt_schedule = self._cfg.hpo.fts.gp_opt_schedule
|
|
pt = np.ones(max_iter + 5)
|
|
LocalBO(cid=self.ID,
|
|
f=self._obj_func,
|
|
bounds=self.bounds,
|
|
keys=list(self.pbounds.keys()),
|
|
gp_opt_schedule=gp_opt_schedule,
|
|
use_init=None,
|
|
log_file=self.local_bo_path,
|
|
save_init=True,
|
|
save_init_file=self.local_init_path,
|
|
pt=pt,
|
|
P_N=None,
|
|
ls=self._cfg.hpo.fts.ls,
|
|
var=self._cfg.hpo.fts.var,
|
|
g_var=self._cfg.hpo.fts.g_var,
|
|
N=self._cfg.federate.client_num - 1,
|
|
M_target=M_target).maximize(n_iter=max_iter, init_points=3)
|
|
|
|
# generate local RFF information
|
|
res = pickle.load(open(self.local_bo_path, "rb"))
|
|
ys = np.array(res["all"]["values"]).reshape(-1, 1)
|
|
params = np.array(res["all"]["params"])
|
|
xs = np.array(params)
|
|
xs, ys = xs[:max_iter], ys[:max_iter]
|
|
Phi = np.zeros((xs.shape[0], M))
|
|
|
|
s, b = rand_feats["s"], rand_feats["b"]
|
|
for i, x in enumerate(xs):
|
|
x = np.squeeze(x).reshape(1, -1)
|
|
features = np.sqrt(2 / M) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
|
|
features = features / np.sqrt(np.inner(features, features))
|
|
features = np.sqrt(v_kernel) * features
|
|
Phi[i, :] = features
|
|
|
|
Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M)
|
|
Sigma_t_inv = np.linalg.inv(Sigma_t)
|
|
nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T), ys)
|
|
w_samples = np.random.multivariate_normal(np.squeeze(nu_t),
|
|
obs_noise * Sigma_t_inv, 1)
|
|
pickle.dump(w_samples, open(self.local_info_path, "wb"))
|
|
|
|
def callback_funcs_for_model_para(self, message: Message):
|
|
round, sender, content = message.state, message.sender, message.content
|
|
require_agent_infos = content['require_agent_infos']
|
|
|
|
# generate local info and init then send them to server
|
|
if require_agent_infos:
|
|
rand_feat = content['random_feats']
|
|
self._generate_agent_info(rand_feat)
|
|
agent_info = pickle.load(open(self.local_info_path, "rb"))
|
|
agent_init = pickle.load(open(self.local_init_path, "rb"))
|
|
content = {
|
|
'is_required_agent_info': True,
|
|
'agent_info': agent_info,
|
|
'agent_init': agent_init,
|
|
}
|
|
|
|
# local run on given hyper-param and return performance
|
|
else:
|
|
x_max = content['x_max']
|
|
curr_y, eval_res = self._obj_func(x_max, return_eval=True)
|
|
content = {
|
|
'is_required_agent_info': False,
|
|
'curr_y': curr_y,
|
|
}
|
|
hyper_param = x2conf(x_max, self.pbounds, self._ss)
|
|
logger.info('{Client: %d, ' % self.ID +
|
|
'GP_opt_iter: %d, ' % round + 'Params: ' +
|
|
str(hyper_param) + ', Perform: ' + str(curr_y) + '}')
|
|
logger.info(
|
|
self._monitor.format_eval_res(eval_res,
|
|
rnd=self.state,
|
|
role='Client #{}'.format(
|
|
self.ID),
|
|
return_raw=True))
|
|
|
|
self.state = round
|
|
self.comm_manager.send(
|
|
Message(msg_type='model_para',
|
|
sender=self.ID,
|
|
receiver=[sender],
|
|
state=self.state,
|
|
content=content))
|
|
|
|
def callback_funcs_for_evaluate(self, message: Message):
|
|
round, sender, content = \
|
|
message.state, message.sender, message.content
|
|
require_agent_infos = content['require_agent_infos']
|
|
assert not require_agent_infos, \
|
|
"Can not evaluate when there is no agents' information"
|
|
|
|
self.state = message.state
|
|
self._obj_func(content['x_max'])
|
|
|
|
metrics = {}
|
|
for split in self._cfg.eval.split:
|
|
eval_metrics = self.trainer.evaluate(target_data_split_name=split)
|
|
for key in eval_metrics:
|
|
if self._cfg.federate.mode == 'distributed':
|
|
logger.info('Client #{:d}: (Evaluation ({:s} set) at '
|
|
'Round #{:d}) {:s} is {:.6f}'.format(
|
|
self.ID, split, self.state, key,
|
|
eval_metrics[key]))
|
|
metrics.update(**eval_metrics)
|
|
|
|
self.comm_manager.send(
|
|
Message(msg_type='metrics',
|
|
sender=self.ID,
|
|
receiver=[sender],
|
|
state=self.state,
|
|
content=metrics))
|