import os
import logging
import json
import copy
import pickle

import numpy as np

from federatedscope.core.message import Message
from federatedscope.core.workers import Client
from federatedscope.autotune.fts.utils import *
from federatedscope.autotune.utils import parse_search_space
from federatedscope.core.auxiliaries.trainer_builder import get_trainer

logger = logging.getLogger(__name__)


def _read_pickle(path):
    """Load and return the object pickled at ``path``, closing the file.

    NOTE(security): ``pickle`` can execute arbitrary code while loading.
    Every path read in this module is written by this client itself inside
    the HPO working folder; do not point it at untrusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


def _write_pickle(obj, path):
    """Pickle ``obj`` to ``path`` and close (hence flush) the file."""
    with open(path, "wb") as handle:
        pickle.dump(obj, handle)


class FTSClient(Client):
    """Client used by the FTS hyperparameter-optimization routine.

    The client keeps a pristine copy of the initial model so that every
    objective evaluation starts from the same weights, runs a local BO
    (``LocalBO``) to produce its own observations, and answers the server's
    ``model_para`` / ``evaluate`` requests.
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        """Set up the client, its on-disk file locations and search bounds.

        Raises:
            ValueError: if an entry of the parsed search space does not
                expose both ``lower`` and ``upper`` attributes.
        """
        super(FTSClient,
              self).__init__(ID, server_id, state, config, data, model, device,
                             strategy, is_unseen_client, *args, **kwargs)
        self.data = data
        self.model = model
        self.device = device
        self._diff = config.hpo.fts.diff
        # Untouched copy of the model; `_get_new_trainer` clones it again
        # before each objective evaluation.
        self._init_model = copy.deepcopy(model)

        # local file paths
        working_folder = self._cfg.hpo.working_folder
        client_tag = str(self.ID)
        self.local_bo_path = os.path.join(working_folder,
                                          "local_bo_" + client_tag + ".pkl")
        self.local_init_path = os.path.join(
            working_folder, "local_init_" + client_tag + ".pkl")
        self.local_info_path = os.path.join(
            working_folder, "local_info_" + client_tag + "_M_" +
            str(self._cfg.hpo.fts.M) + ".pkl")

        # prepare search space and bounds
        self._ss = parse_search_space(self._cfg.hpo.fts.ss)
        self.dim = len(self._ss)
        # The optimizer works on the unit hyper-cube; `pbounds` holds the
        # real (possibly log10-scaled) range of every hyperparameter.
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}
        for name, hyper in self._ss.items():
            if not (hasattr(hyper, 'lower') and hasattr(hyper, 'upper')):
                raise ValueError("Unsupported hyper type {}".format(
                    type(hyper)))
            if hyper.log:
                lower, upper = np.log10(hyper.lower), np.log10(hyper.upper)
            else:
                lower, upper = hyper.lower, hyper.upper
            self.pbounds[name] = (lower, upper)

    def _apply_hyperparams(self, hyperparams):
        """Merge the given hyperparameters into the client's config.

        Arguments:
            hyperparams (dict): keys are hyperparameter names and values
                are the specific choices to apply.
        """
        # `merge_from_list` expects a flat [key1, value1, key2, value2, ...].
        cmd_args = []
        for name, value in hyperparams.items():
            cmd_args.extend((name, value))

        self._cfg.defrost()
        self._cfg.merge_from_list(cmd_args)
        self._cfg.freeze(inform=False)

        # Let the trainer context pick up the updated config values.
        self.trainer.ctx.setup_vars()

    def _get_new_trainer(self):
        """Rebuild ``self.model`` and ``self.trainer`` from the initial model.
        """
        self.model = copy.deepcopy(self._init_model)
        self.trainer = get_trainer(model=self.model,
                                   data=self.data,
                                   device=self.device,
                                   config=self._cfg,
                                   is_attacker=self.is_attacker,
                                   monitor=self._monitor)

    def _obj_func(self, x, return_eval=False):
        """Run local training for ``x`` and return a score to maximize.

        Arguments:
            x: point handed to ``x2conf`` together with ``self.pbounds`` and
                the search space to obtain concrete hyperparameters.
            return_eval (bool): when True, also return the validation
                results obtained after training.

        Returns:
            The score, or ``(score, results_after)`` when ``return_eval`` is
            True. With ``hpo.fts.diff`` set the score is the drop in
            ``val_avg_loss`` achieved by training; otherwise it is
            ``baseline - val_avg_loss`` after training.
        """
        self._get_new_trainer()
        # Constant offset that turns the loss into a quantity to maximize.
        baseline = 5.0

        hyperparams = x2conf(x, self.pbounds, self._ss)
        self._apply_hyperparams(hyperparams)

        results_before = self.trainer.evaluate('val')
        for _ in range(self._cfg.hpo.fts.local_bo_epochs):
            self.trainer.train()
        results_after = self.trainer.evaluate('val')

        if self._diff:
            res = results_before['val_avg_loss'] \
                - results_after['val_avg_loss']
        else:
            res = baseline - results_after['val_avg_loss']

        if return_eval:
            return res, results_after
        return res

    def _generate_agent_info(self, rand_feats):
        """Run a local BO and store sampled feature weights on disk.

        Arguments:
            rand_feats (dict): provides the random-feature parameters under
                the keys ``"s"`` and ``"b"``.

        Side effects:
            ``LocalBO`` is given ``self.local_bo_path`` as its log file and
            ``self.local_init_path`` as its init file; the sampled weights
            are pickled to ``self.local_info_path``.
        """
        logger.info('%s generate info on client %d %s', '-' * 20, self.ID,
                    '_' * 20)
        v_kernel = self._cfg.hpo.fts.v_kernel
        obs_noise = self._cfg.hpo.fts.obs_noise
        M = self._cfg.hpo.fts.M
        M_target = self._cfg.hpo.fts.M_target

        # run standard BO locally
        max_iter = self._cfg.hpo.fts.local_bo_max_iter
        gp_opt_schedule = self._cfg.hpo.fts.gp_opt_schedule
        pt = np.ones(max_iter + 5)
        LocalBO(cid=self.ID,
                f=self._obj_func,
                bounds=self.bounds,
                keys=list(self.pbounds.keys()),
                gp_opt_schedule=gp_opt_schedule,
                use_init=None,
                log_file=self.local_bo_path,
                save_init=True,
                save_init_file=self.local_init_path,
                pt=pt,
                P_N=None,
                ls=self._cfg.hpo.fts.ls,
                var=self._cfg.hpo.fts.var,
                g_var=self._cfg.hpo.fts.g_var,
                N=self._cfg.federate.client_num - 1,
                M_target=M_target).maximize(n_iter=max_iter, init_points=3)

        # generate local RFF information
        res = _read_pickle(self.local_bo_path)
        ys = np.array(res["all"]["values"]).reshape(-1, 1)
        xs = np.array(res["all"]["params"])
        # Keep at most the first `max_iter` observations.
        xs, ys = xs[:max_iter], ys[:max_iter]

        Phi = np.zeros((xs.shape[0], M))
        s, b = rand_feats["s"], rand_feats["b"]
        for i, x in enumerate(xs):
            x = np.squeeze(x).reshape(1, -1)
            features = np.sqrt(
                2 / M) * np.cos(np.squeeze(np.dot(x, s.T)) + b)
            # Normalize to unit length, then scale by sqrt(v_kernel).
            features = features / np.sqrt(np.inner(features, features))
            features = np.sqrt(v_kernel) * features
            Phi[i, :] = features

        # Regularized Gram matrix of the features and the resulting mean.
        Sigma_t = np.dot(Phi.T, Phi) + obs_noise * np.identity(M)
        Sigma_t_inv = np.linalg.inv(Sigma_t)
        nu_t = np.dot(np.dot(Sigma_t_inv, Phi.T), ys)
        # Draw a single weight sample around that mean.
        w_samples = np.random.multivariate_normal(np.squeeze(nu_t),
                                                  obs_noise * Sigma_t_inv, 1)
        _write_pickle(w_samples, self.local_info_path)

    def callback_funcs_for_model_para(self, message: Message):
        """Handle a ``model_para`` message and reply to its sender.

        When the message content sets ``require_agent_infos``, the client
        generates its local info and init data and sends them back.
        Otherwise it evaluates the received ``x_max`` locally and sends back
        the obtained score.
        """
        rnd, sender, content = message.state, message.sender, message.content
        require_agent_infos = content['require_agent_infos']

        # generate local info and init then send them to server
        if require_agent_infos:
            self._generate_agent_info(content['random_feats'])
            reply = {
                'is_required_agent_info': True,
                'agent_info': _read_pickle(self.local_info_path),
                'agent_init': _read_pickle(self.local_init_path),
            }
        # local run on given hyper-param and return performance
        else:
            x_max = content['x_max']
            curr_y, eval_res = self._obj_func(x_max, return_eval=True)
            reply = {
                'is_required_agent_info': False,
                'curr_y': curr_y,
            }
            hyper_param = x2conf(x_max, self.pbounds, self._ss)
            logger.info(
                '{Client: %d, GP_opt_iter: %d, Params: %s, Perform: %s}',
                self.ID, rnd, hyper_param, curr_y)
            # NOTE(review): `self.state` is only updated below, so this line
            # reports the previously stored state - confirm this is intended.
            logger.info(
                self._monitor.format_eval_res(eval_res,
                                              rnd=self.state,
                                              role='Client #{}'.format(
                                                  self.ID),
                                              return_raw=True))

        self.state = rnd
        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=reply))

    def callback_funcs_for_evaluate(self, message: Message):
        """Handle an ``evaluate`` message and send back the metrics.

        The received ``x_max`` is first run through ``_obj_func`` (which
        rebuilds and trains the trainer); the trainer is then evaluated on
        every split listed in ``eval.split``.
        """
        sender, content = message.sender, message.content
        require_agent_infos = content['require_agent_infos']
        assert not require_agent_infos, \
            "Can not evaluate when there is no agents' information"

        self.state = message.state
        self._obj_func(content['x_max'])

        metrics = {}
        for split in self._cfg.eval.split:
            eval_metrics = self.trainer.evaluate(target_data_split_name=split)
            if self._cfg.federate.mode == 'distributed':
                for key in eval_metrics:
                    logger.info('Client #{:d}: (Evaluation ({:s} set) at '
                                'Round #{:d}) {:s} is {:.6f}'.format(
                                    self.ID, split, self.state, key,
                                    eval_metrics[key]))
            # One merge per split is enough to collect all of its metrics.
            metrics.update(**eval_metrics)

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=metrics))