# FS-TFP/federatedscope/autotune/pfedhpo/client.py

import os
import logging
import copy

import torch

from federatedscope.core.message import Message
from federatedscope.core.workers import Client
from federatedscope.autotune.pfedhpo.utils import *

logger = logging.getLogger(__name__)


class pFedHPOClient(Client):
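    """Client for pFedHPO. When both ``hpo.pfedhpo.train_fl`` and
    ``hpo.pfedhpo.train_anchor`` are enabled, it pre-computes a noisy
    embedding of its local training data as a per-client encoding; during
    federated training it applies the hyperparameters received from the
    server before running local updates.
    """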
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 is_unseen_client=False,
                 *args,
                 **kwargs):
        super(pFedHPOClient,
              self).__init__(ID, server_id, state, config, data, model,
                             device, strategy, is_unseen_client, *args,
                             **kwargs)
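
        # When both train_fl and train_anchor are enabled, pre-compute a
        # noisy random-feature embedding of the local training data and save
        # it to the HPO working folder as this client's encoding
        # ('client_<ID>_encoding.pt'). The embedding dimensions (d_enc,
        # n_label, d_rff) are dataset-dependent.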
        if self._cfg.hpo.pfedhpo.train_fl and \
                self._cfg.hpo.pfedhpo.train_anchor:
            if self._cfg.data.type == 'mini-graph-dc':
                d_enc = 74
                graph = True
                n_label = 6
                d_rff = 10
            elif 'cifar' in str(self._cfg.data.type).lower():
                d_enc = 32 * 32 * 3
                graph = False
                n_label = 10
                d_rff = 6
            elif 'femnist' in str(self._cfg.data.type).lower():
                d_enc = 28 * 28
                graph = False
                n_label = 62
                d_rff = 2
            elif 'twitter' in str(self._cfg.data.type).lower():
                d_enc = 400000
                graph = False
                n_label = 2
                d_rff = 50
            else:
                raise NotImplementedError
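
            # Compute a noisy dataset embedding for each RFF sigma and
            # concatenate them into a single flat feature vector.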
            mmd_type = 'sphere'
            rff_sigma = [
                127,
            ]
            rff_sigma = [float(sig) for sig in rff_sigma]
            embs = []
            for sig in rff_sigma:
                emb = noisy_dataset_embedding(data['train'],
                                              d_enc,
                                              sig,
                                              d_rff,
                                              device,
                                              n_labels=n_label,
                                              noise_factor=0.1,
                                              mmd_type=mmd_type,
                                              sum_frequency=25,
                                              graph=graph)
                embs.append(emb)
            feats = torch.cat(embs).reshape(-1)
            torch.save(
                feats,
                os.path.join(self._cfg.hpo.working_folder,
                             'client_%d_encoding.pt' % self.ID))

    def _apply_hyperparams(self, hyperparams):
        """Apply the given hyperparameters

        Arguments:
            hyperparams (dict): keys are hyperparameter names \
                and values are specific choices.
        """
        cmd_args = []
        for k, v in hyperparams.items():
            cmd_args.append(k)
            cmd_args.append(v)

        self._cfg.defrost()
        self._cfg.merge_from_list(cmd_args)
        self._cfg.freeze(inform=False)

        # self.trainer.ctx.setup_vars()

    def callback_funcs_for_model_para(self, message: Message):
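        """Handle a 'model_para' message from the server: apply the
        hyperparameters it carries, load the received model parameters,
        run local training, and send the update back.
        """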
        round, sender, content = message.state, message.sender, \
            message.content
        model_params, hyperparams = content["model_param"], \
            content["hyper_param"]

        attempt = {
            'Role': 'Client #{:d}'.format(self.ID),
            'Hyperparams': hyperparams
        }
        logger.info('-' * 30)
        logger.info(attempt)

        if hyperparams is not None:
            self._apply_hyperparams(hyperparams)

        self.trainer.update(model_params)

        # self.model.load_state_dict(content)
        self.state = round
        sample_size, model_para_all, results = self.trainer.train()
        if self._cfg.federate.share_local_model and not \
                self._cfg.federate.online_aggr:
            model_para_all = copy.deepcopy(model_para_all)
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state,
                                          role='Client #{}'.format(self.ID),
                                          return_raw=True))
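
        # Report (sample size, updated model parameters, training results)
        # back to the server.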
        content = (sample_size, model_para_all, results)
        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=content))

    def callback_funcs_for_evaluate(self, message: Message):
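        """Handle an 'evaluate' message from the server: optionally update
        the local model and fine-tune, evaluate on the configured data
        splits, and report the metrics back.
        """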
        sender = message.sender
        self.state = message.state
        if message.content is not None:
            model_params = message.content["model_param"]
            self.trainer.update(model_params)
        if self._cfg.finetune.before_eval:
            self.trainer.finetune()
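
        # Evaluate every configured split and merge the results into a
        # single metrics report.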
        metrics = {}
        for split in self._cfg.eval.split:
            eval_metrics = self.trainer.evaluate(
                target_data_split_name=split)
            for key in eval_metrics:
                if self._cfg.federate.mode == 'distributed':
                    logger.info('Client #{:d}: (Evaluation ({:s} set) at '
                                'Round #{:d}) {:s} is {:.6f}'.format(
                                    self.ID, split, self.state, key,
                                    eval_metrics[key]))
            metrics.update(**eval_metrics)

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=metrics))