import copy
import logging
import os

import torch

from federatedscope.core.message import Message
from federatedscope.core.workers import Client
# Star import kept for backward compatibility; it provides
# ``noisy_dataset_embedding`` used below.
from federatedscope.autotune.pfedhpo.utils import *  # noqa: F401,F403

logger = logging.getLogger(__name__)


class pFedHPOClient(Client):
    """Client used by pFedHPO.

    On construction, when both ``hpo.pfedhpo.train_fl`` and
    ``hpo.pfedhpo.train_anchor`` are enabled, the client computes a noisy
    embedding of its training data and saves it to
    ``<hpo.working_folder>/client_<ID>_encoding.pt``.

    During training, hyperparameters received from the server are merged into
    the client's config before local training.
    """

    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        super(pFedHPOClient, self).__init__(ID, server_id, state, config, data,
                                            model, device, strategy,
                                            is_unseen_client, *args, **kwargs)

        if self._cfg.hpo.pfedhpo.train_fl and \
                self._cfg.hpo.pfedhpo.train_anchor:
            self._save_data_encoding(data, device)

    def _get_encoding_spec(self):
        """Return the encoding settings for the configured dataset.

        Returns:
            tuple: ``(d_enc, graph, n_label, d_rff)``, which are passed to
            ``noisy_dataset_embedding``. ``d_enc`` presumably is the flattened
            input dimension and ``n_label`` the number of classes (inferred
            from the values, e.g. 32*32*3 / 10 for CIFAR) -- TODO confirm.

        Raises:
            NotImplementedError: if ``data.type`` matches none of the
                supported datasets.
        """
        data_type = str(self._cfg.data.type).lower()
        # The graph dataset is matched exactly; the others by substring,
        # in this priority order.
        if self._cfg.data.type == 'mini-graph-dc':
            return 74, True, 6, 10
        if 'cifar' in data_type:
            return 32 * 32 * 3, False, 10, 6
        if 'femnist' in data_type:
            return 28 * 28, False, 62, 2
        if 'twitter' in data_type:
            return 400000, False, 2, 50
        raise NotImplementedError

    def _save_data_encoding(self, data, device):
        """Embed ``data['train']`` with noise and save it as a flat tensor.

        The result is written to
        ``<hpo.working_folder>/client_<ID>_encoding.pt``.
        """
        d_enc, graph, n_label, d_rff = self._get_encoding_spec()
        mmd_type = 'sphere'
        rff_sigma = [127.0]

        embs = [
            noisy_dataset_embedding(data['train'],
                                    d_enc,
                                    sig,
                                    d_rff,
                                    device,
                                    n_labels=n_label,
                                    noise_factor=0.1,
                                    mmd_type=mmd_type,
                                    sum_frequency=25,
                                    graph=graph) for sig in rff_sigma
        ]
        feats = torch.cat(embs).reshape(-1)

        folder = self._cfg.hpo.working_folder
        # torch.save fails if the target directory does not exist yet.
        if folder:
            os.makedirs(folder, exist_ok=True)
        torch.save(feats,
                   os.path.join(folder, 'client_%d_encoding.pt' % self.ID))

    def _apply_hyperparams(self, hyperparams):
        """Merge the given hyperparameters into the client's config.

        Arguments:
            hyperparams (dict): keys are hyperparameter (config) names and
                values are the specific choices.
        """
        cmd_args = []
        for k, v in hyperparams.items():
            cmd_args.extend([k, v])

        self._cfg.defrost()
        self._cfg.merge_from_list(cmd_args)
        self._cfg.freeze(inform=False)

    def callback_funcs_for_model_para(self, message: Message):
        """Handle a message carrying global model params and hyperparams.

        The steps are:

        1. Apply ``content["hyper_param"]`` (if not None).
        2. Load ``content["model_param"]`` into the trainer.
        3. Train locally.
        4. Send ``(sample_size, model_para_all, results)`` back to the sender
           as a ``'model_para'`` message.
        """
        rnd, sender, content = message.state, message.sender, message.content
        model_params = content["model_param"]
        hyperparams = content["hyper_param"]

        attempt = {
            'Role': 'Client #{:d}'.format(self.ID),
            'Hyperparams': hyperparams
        }
        logger.info('-' * 30)
        logger.info(attempt)

        if hyperparams is not None:
            self._apply_hyperparams(hyperparams)

        self.trainer.update(model_params)
        self.state = rnd

        sample_size, model_para_all, results = self.trainer.train()
        if self._cfg.federate.share_local_model and not \
                self._cfg.federate.online_aggr:
            # Copy so the sent params are not aliased by the local model.
            model_para_all = copy.deepcopy(model_para_all)

        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID),
                                          return_raw=True))

        content = (sample_size, model_para_all, results)
        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=content))

    def callback_funcs_for_evaluate(self, message: Message):
        """Evaluate on each split in ``eval.split`` and report the metrics.

        If the message carries ``content["model_param"]``, it is loaded first.
        If ``finetune.before_eval`` is set, the trainer is finetuned before
        evaluation. All metrics are merged into one dict and sent back to the
        sender as a ``'metrics'`` message.
        """
        sender = message.sender
        self.state = message.state
        if message.content is not None:
            model_params = message.content["model_param"]
            self.trainer.update(model_params)
        if self._cfg.finetune.before_eval:
            self.trainer.finetune()

        is_distributed = self._cfg.federate.mode == 'distributed'
        metrics = {}
        for split in self._cfg.eval.split:
            eval_metrics = self.trainer.evaluate(target_data_split_name=split)
            if is_distributed:
                for key, value in eval_metrics.items():
                    logger.info('Client #{:d}: (Evaluation ({:s} set) at '
                                'Round #{:d}) {:s} is {:.6f}'.format(
                                    self.ID, split, self.state, key, value))
            metrics.update(eval_metrics)

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=metrics))