166 lines
5.6 KiB
Python
166 lines
5.6 KiB
Python
import os
|
|
import logging
|
|
import copy
|
|
|
|
from federatedscope.core.message import Message
|
|
from federatedscope.core.workers import Client
|
|
from federatedscope.autotune.pfedhpo.utils import *
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class pFedHPOClient(Client):
|
|
def __init__(self,
|
|
ID=-1,
|
|
server_id=None,
|
|
state=-1,
|
|
config=None,
|
|
data=None,
|
|
model=None,
|
|
device='cpu',
|
|
strategy=None,
|
|
is_unseen_client=False,
|
|
*args,
|
|
**kwargs):
|
|
|
|
super(pFedHPOClient,
|
|
self).__init__(ID, server_id, state, config, data, model, device,
|
|
strategy, is_unseen_client, *args, **kwargs)
|
|
|
|
if self._cfg.hpo.pfedhpo.train_fl and \
|
|
self._cfg.hpo.pfedhpo.train_anchor:
|
|
|
|
if self._cfg.data.type == 'mini-graph-dc':
|
|
d_enc = 74
|
|
graph = True
|
|
n_label = 6
|
|
d_rff = 10
|
|
|
|
elif 'cifar' in str(self._cfg.data.type).lower():
|
|
d_enc = 32 * 32 * 3
|
|
graph = False
|
|
n_label = 10
|
|
d_rff = 6
|
|
|
|
elif 'femnist' in str(self._cfg.data.type).lower():
|
|
d_enc = 28 * 28
|
|
graph = False
|
|
n_label = 62
|
|
d_rff = 2
|
|
|
|
elif 'twitter' in str(self._cfg.data.type).lower():
|
|
d_enc = 400000
|
|
graph = False
|
|
n_label = 2
|
|
d_rff = 50
|
|
|
|
else:
|
|
raise NotImplementedError
|
|
|
|
mmd_type = 'sphere'
|
|
rff_sigma = [
|
|
127,
|
|
]
|
|
rff_sigma = [float(sig) for sig in rff_sigma]
|
|
|
|
embs = []
|
|
for sig in rff_sigma:
|
|
emb = noisy_dataset_embedding(data['train'],
|
|
d_enc,
|
|
sig,
|
|
d_rff,
|
|
device,
|
|
n_labels=n_label,
|
|
noise_factor=0.1,
|
|
mmd_type=mmd_type,
|
|
sum_frequency=25,
|
|
graph=graph)
|
|
embs.append(emb)
|
|
feats = torch.cat(embs).reshape(-1)
|
|
|
|
torch.save(
|
|
feats,
|
|
os.path.join(self._cfg.hpo.working_folder,
|
|
'client_%d_encoding.pt' % self.ID))
|
|
|
|
def _apply_hyperparams(self, hyperparams):
|
|
"""Apply the given hyperparameters
|
|
Arguments:
|
|
hyperparams (dict): keys are hyperparameter names \
|
|
and values are specific choices.
|
|
"""
|
|
|
|
cmd_args = []
|
|
for k, v in hyperparams.items():
|
|
cmd_args.append(k)
|
|
cmd_args.append(v)
|
|
|
|
self._cfg.defrost()
|
|
self._cfg.merge_from_list(cmd_args)
|
|
self._cfg.freeze(inform=False)
|
|
|
|
# self.trainer.ctx.setup_vars()
|
|
|
|
def callback_funcs_for_model_para(self, message: Message):
|
|
round, sender, content = message.state, message.sender, message.content
|
|
model_params, hyperparams = content["model_param"], content[
|
|
"hyper_param"]
|
|
attempt = {
|
|
'Role': 'Client #{:d}'.format(self.ID),
|
|
'Hyperparams': hyperparams
|
|
}
|
|
logger.info('-' * 30)
|
|
logger.info(attempt)
|
|
|
|
if hyperparams is not None:
|
|
self._apply_hyperparams(hyperparams)
|
|
|
|
self.trainer.update(model_params)
|
|
|
|
# self.model.load_state_dict(content)
|
|
self.state = round
|
|
sample_size, model_para_all, results = self.trainer.train()
|
|
|
|
if self._cfg.federate.share_local_model and not \
|
|
self._cfg.federate.online_aggr:
|
|
model_para_all = copy.deepcopy(model_para_all)
|
|
logger.info(
|
|
self._monitor.format_eval_res(results,
|
|
rnd=self.state,
|
|
role='Client #{}'.format(self.ID),
|
|
return_raw=True))
|
|
|
|
content = (sample_size, model_para_all, results)
|
|
self.comm_manager.send(
|
|
Message(msg_type='model_para',
|
|
sender=self.ID,
|
|
receiver=[sender],
|
|
state=self.state,
|
|
content=content))
|
|
|
|
def callback_funcs_for_evaluate(self, message: Message):
|
|
sender = message.sender
|
|
self.state = message.state
|
|
if message.content is not None:
|
|
model_params = message.content["model_param"]
|
|
self.trainer.update(model_params)
|
|
if self._cfg.finetune.before_eval:
|
|
self.trainer.finetune()
|
|
metrics = {}
|
|
for split in self._cfg.eval.split:
|
|
eval_metrics = self.trainer.evaluate(target_data_split_name=split)
|
|
for key in eval_metrics:
|
|
|
|
if self._cfg.federate.mode == 'distributed':
|
|
logger.info('Client #{:d}: (Evaluation ({:s} set) at '
|
|
'Round #{:d}) {:s} is {:.6f}'.format(
|
|
self.ID, split, self.state, key,
|
|
eval_metrics[key]))
|
|
metrics.update(**eval_metrics)
|
|
self.comm_manager.send(
|
|
Message(msg_type='metrics',
|
|
sender=self.ID,
|
|
receiver=[sender],
|
|
state=self.state,
|
|
content=metrics))
|