import copy
import json
import os
import logging
from itertools import product
import pickle

import numpy as np
import torch.nn
import yaml

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict
from federatedscope.autotune.pfedhpo.utils import *
from federatedscope.autotune.utils import parse_search_space

logger = logging.getLogger(__name__)


class pFedHPOFLServer(Server):
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        # initialize action space and the policy
        with open(config.hpo.pfedhpo.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)

        if next(iter(ss.keys())).startswith('arm'):
            # This is a flattened action space
            # ensure the order is unchanged
            ss = sorted([(int(k[3:]), v) for k, v in ss.items()],
                        key=lambda x: x[0])
            self._grid = []
            self._cfsp = [[tp[1] for tp in ss]]
        else:
            # This is not a flat search space
            # be careful with the order
            self._grid = sorted(ss.keys())
            self._cfsp = [ss[pn] for pn in self._grid]

        super(pFedHPOFLServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        os.makedirs(self._cfg.hpo.working_folder, exist_ok=True)
        self.train_anchor = self._cfg.hpo.pfedhpo.train_anchor
        self.discrete = self._cfg.hpo.pfedhpo.discrete

        # prepare search space and bounds
        self._ss = parse_search_space(self._cfg.hpo.pfedhpo.ss)
        self.dim = len(self._ss)
        self.bounds = np.asarray([(0., 1.) for _ in self._ss])
        self.pbounds = {}

        if not self.discrete:
            for k, v in self._ss.items():
                if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                    raise ValueError("Unsupported hyper type {}".format(
                        type(v)))
                else:
                    if v.log:
                        l, u = np.log10(v.lower), np.log10(v.upper)
                    else:
                        l, u = v.lower, v.upper
                    self.pbounds[k] = (l, u)
        else:
            for k, v in self._ss.items():
                if not (hasattr(v, 'lower') and hasattr(v, 'upper')):
                    if hasattr(v, 'choices'):
                        self.pbounds[k] = list(v.choices)
                    else:
                        raise ValueError("Unsupported hyper type {}".format(
                            type(v)))
                else:
                    if v.log:
                        l, u = np.log10(v.lower), np.log10(v.upper)
                    else:
                        l, u = v.lower, v.upper
                    # discretize the continuous range into N_samp evenly
                    # spaced candidates
                    N_samp = 10
                    samp = []
                    for i in range(N_samp):
                        samp.append((u - l) / N_samp * i + l)
                    self.pbounds[k] = samp

        # prepare hyper-net
        self.client2idx = None
        if not self.train_anchor:
            hyper_enc = torch.load(
                os.path.join(self._cfg.hpo.working_folder,
                             'hyperNet_encoding.pt'))

            if self._cfg.data.type == 'mini-graph-dc':
                dim = 60
            elif 'cifar' in str(self._cfg.data.type).lower():
                dim = 60
            elif 'femnist' in str(self._cfg.data.type).lower():
                dim = 124
            elif 'twitter' in str(self._cfg.data.type).lower():
                dim = 100
            else:
                raise NotImplementedError
            self.client_encoding = torch.ones(client_num, dim)

            if not self.discrete:
                self.HyperNet = HyperNet(encoding=self.client_encoding,
                                         num_params=len(self.pbounds),
                                         n_clients=client_num,
                                         device=self._cfg.device,
                                         var=0.01).to(self._cfg.device)
            else:
                self.HyperNet = DisHyperNet(
                    encoding=self.client_encoding,
                    cands=self.pbounds,
                    n_clients=client_num,
                    device=self._cfg.device,
                ).to(self._cfg.device)

            self.HyperNet.load_state_dict(hyper_enc['hyperNet'])
            self.HyperNet.eval()

            if not self.discrete:
                self.raw_params = self.HyperNet()[0].detach().cpu().numpy()
            else:
                self.logits = self.HyperNet()[0]

    def callback_funcs_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, \
            message.content
        self.sampler.change_state(sender, 'idle')
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()
        self.msg_buffer['train'][round][sender] = content

        if self._cfg.federate.online_aggr:
            try:
                self.aggregator.inc(tuple(content[0:2]))
            except Exception:
                pass

        return self.check_and_move_on()

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients.

        Arguments:
            msg_type: 'model_para' or other user-defined msg_type
            sample_client_num: the number of sampled clients in the
                broadcast behavior. sample_client_num = -1 denotes
                broadcasting to all the clients.
            filter_unseen_clients: whether to filter out the unseen clients
                that do not contribute to the FL process by training on
                their local data and uploading their local model update.
                The splitting is useful to check the participation
                generalization gap in [ICLR'22, What Do We Mean by
                Generalization in Federated Learning?]. You may want to set
                it to False in the evaluation stage.
        """
        if self.train_anchor:
            # checkpoint the anchor model of the current round
            ckpt_path = os.path.join(
                self._cfg.hpo.working_folder,
                'temp_model_round_%d.pt' % (int(self.state)))
            torch.save(self.model.state_dict(), ckpt_path)

        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            receiver = self.sampler.sample(size=sample_client_num)
        else:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')

        if self._noise_injector is not None and msg_type == 'model_para':
            # Inject noise only when broadcasting parameters
            for model_idx_i in range(len(self.models)):
                num_sample_clients = [
                    v["num_sample"] for v in self.join_in_info.values()
                ]
                self._noise_injector(self._cfg, num_sample_clients,
                                     self.models[model_idx_i])

        skip_broadcast = self._cfg.federate.method in ["local", "global"]
        if self.model_num > 1:
            model_para = [{} if skip_broadcast else model.state_dict()
                          for model in self.models]
        else:
            model_para = {} if skip_broadcast else self.model.state_dict()

        if not self.client2idx:
            # map client IDs to the consecutive indices used by the hyper-net
            client2idx = {}
            _all_clients = list(self.comm_manager.neighbors.keys())
            for i, k in zip(range(len(_all_clients)), _all_clients):
                client2idx[k] = i
            self.client2idx = client2idx

        for rcv_idx in receiver:
            if self.train_anchor:
                sampled_cfg = None
            else:
                if not self.discrete:
                    sampled_cfg = map_value_to_param(
                        self.raw_params[self.client2idx[rcv_idx]],
                        self.pbounds, self._ss)
                else:
                    # pick the candidate with the largest logit for each
                    # hyperparameter
                    sampled_cfg = {}
                    for i, (k, v) in zip(range(len(self.pbounds)),
                                         self.pbounds.items()):
                        probs = self.logits[i][self.client2idx[rcv_idx]]
                        p = v[torch.argmax(probs).item()]
                        if hasattr(self._ss[k], 'log') and self._ss[k].log:
                            p = 10**p
                        if 'int' in str(type(self._ss[k])).lower():
                            sampled_cfg[k] = int(p)
                        else:
                            sampled_cfg[k] = float(p)

            content = {'model_param': model_para, 'hyper_param': sampled_cfg}
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))

        if self._cfg.federate.online_aggr:
            try:
                for idx in range(self.model_num):
                    self.aggregators[idx].reset()
            except Exception:
                pass

        if filter_unseen_clients:
            # restore the state of the unseen clients within the sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def save_res(self, feedbacks):
        feedbacks = {'round': self.state, 'results': feedbacks}
        line = str(feedbacks) + "\n"
        with open(
                os.path.join(self._cfg.hpo.working_folder,
                             'anchor_eval_results.log'), "a") as outfile:
            outfile.write(line)

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message buffer and, when enough messages have been
        received, trigger some events (such as performing aggregation and
        evaluation, and moving to the next training round).
        """
        if min_received_num is None:
            min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result:
            # evaluation results are expected from all clients
            min_received_num = len(list(self.comm_manager.neighbors.keys()))

        # To record whether we move to a new training round or finish the
        # evaluation
        move_on_flag = True
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:  # in the training process
                mab_feedbacks = dict()
                # Get all the messages
                train_msg_buffer = self.msg_buffer['train'][self.state]
                for model_idx in range(self.model_num):
                    model = self.models[model_idx]
                    aggregator = self.aggregators[model_idx]
                    msg_list = list()
                    for client_id in train_msg_buffer:
                        if self.model_num == 1:
                            msg_list.append(
                                tuple(train_msg_buffer[client_id][0:2]))
                        else:
                            train_data_size, model_para_multiple = \
                                train_msg_buffer[client_id][0:2]
                            msg_list.append(
                                (train_data_size,
                                 model_para_multiple[model_idx]))

                        # collect feedbacks for updating the policy
                        if model_idx == 0:
                            mab_feedbacks[client_id] = train_msg_buffer[
                                client_id][2]

                    # Trigger the monitor here (for training)
                    if 'dissim' in self._cfg.eval.monitoring:
                        from federatedscope.core.auxiliaries.utils import \
                            calc_blocal_dissim
                        # TODO: double-check the proper input for
                        #  calc_blocal_dissim
                        B_val = calc_blocal_dissim(model.state_dict(),
                                                   msg_list)
                        formatted_eval_res = self._monitor.format_eval_res(
                            B_val, rnd=self.state, role='Server #')
                        logger.info(formatted_eval_res)

                    # Aggregate
                    agg_info = {
                        'client_feedback': msg_list,
                        'recover_fun': self.recover_fun
                    }
                    result = aggregator.aggregate(agg_info)
                    model.load_state_dict(result, strict=False)

                self.state += 1

                # Evaluate
                logger.info(
                    'Server: Starting evaluation at round {:d}.'.format(
                        self.state))
                self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()

                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:  # in the evaluation process
                # Get all the messages & aggregate
                logger.info('-' * 30)
                formatted_eval_res = \
                    self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict(self.history_results,
                                                  formatted_eval_res)
                self.check_and_save()
                if self.train_anchor and self.history_results:
                    with open(
                            os.path.join(self._cfg.hpo.working_folder,
                                         'anchor_eval_results.json'),
                            'w') as f:
                        json.dump(self.history_results, f)
        else:
            move_on_flag = False

        return move_on_flag
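

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original module): in the discrete
# setting, __init__ turns every continuous hyperparameter range into N_samp
# evenly spaced candidates, computed in log10 space when the hyperparameter
# is log-scaled, and broadcast_model_para maps the selected candidate back to
# the raw scale with `10 ** p`. The helper below reproduces that candidate
# grid for a hypothetical learning-rate range; the bounds and defaults are
# examples only.
# ---------------------------------------------------------------------------
def _candidate_grid_example(lower=1e-4, upper=1e-1, n_samp=10, log=True):
    """Return the candidate values the server would enumerate for one
    (optionally log-scaled) hyperparameter."""
    lo, up = (np.log10(lower), np.log10(upper)) if log else (lower, upper)
    # same spacing rule as the N_samp loop in __init__
    grid = [(up - lo) / n_samp * i + lo for i in range(n_samp)]
    # same inverse mapping as broadcast_model_para for log-scaled params
    return [10**g for g in grid] if log else grid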