import types
import logging

from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results

logger = logging.getLogger(__name__)


def wrap_swa_server(server):
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # in evaluation stage and standalone simulation mode, we assume
            # strong synchronization that receives responses from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Receiving enough feedback in the training process
                aggregated_num = self._perform_federated_aggregation()
                self.state += 1

                # FedSWA: cache the model(s) for the running weight average
                if self.state == self._cfg.fedswa.start_rnd:
                    self.swa_models_ws = [
                        model.state_dict() for model in self.models
                    ]
                    self.swa_rnd = 1
                elif self.state > \
                        self._cfg.fedswa.start_rnd and \
                        (self.state - self._cfg.fedswa.start_rnd) % \
                        self._cfg.fedswa.freq == 0:
                    logger.info(f'FedSWA cache {self.swa_rnd} models.')
                    for model, new_model in zip(self.swa_models_ws,
                                                self.models):
                        new_model = new_model.state_dict()
                        for key in model.keys():
                            # Running average over the cached checkpoints
                            model[key] = (model[key] * self.swa_rnd +
                                          new_model[key]) / (self.swa_rnd + 1)
                    self.swa_rnd += 1

                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:
                # Receiving enough feedback in the evaluation process
                self._merge_and_format_eval_results()
        else:
            move_on_flag = False

        return move_on_flag

    def eval(self):
        if self._cfg.federate.make_global_eval:
            for i in range(self.model_num):
                trainer = self.trainers[i]
                if self.eval_swa:
                    # Swap in the SWA model and keep the FedAvg weights
                    fedavg_model_w = self.models[i].state_dict()
                    self.models[i].load_state_dict(self.swa_models_ws[i])
                # Perform evaluation in server
                metrics = {}
                for split in self._cfg.eval.split:
                    eval_metrics = trainer.evaluate(
                        target_data_split_name=split)
                    metrics.update(**eval_metrics)
                formatted_eval_res = self._monitor.format_eval_res(
                    metrics,
                    rnd=self.state,
                    role='Server SWA#' if self.eval_swa else 'Server #',
                    forms=self._cfg.eval.report,
                    return_raw=self._cfg.federate.make_global_eval)
                if self.eval_swa:
                    # Restore the FedAvg weights after the SWA evaluation
                    self.models[i].load_state_dict(fedavg_model_w)
                    self.best_results = formatted_eval_res['Results_raw']
                else:
                    self._monitor.update_best_result(
                        self.best_results,
                        formatted_eval_res['Results_raw'],
                        results_type="server_global_eval")
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self._monitor.save_formatted_results(formatted_eval_res)
                logger.info(formatted_eval_res)
            self.check_and_save()
        else:
            if self.eval_swa:
                # Swap in the SWA models and keep the FedAvg weights for
                # every model (a list, so multi-model setups restore
                # correctly)
                fedavg_models_ws = [
                    self.models[i].state_dict()
                    for i in range(self.model_num)
                ]
                for i in range(self.model_num):
                    self.models[i].load_state_dict(self.swa_models_ws[i])
            # Perform evaluation in clients
            self.broadcast_model_para(msg_type='evaluate',
                                      filter_unseen_clients=False)
            if self.eval_swa:
                # Restore the FedAvg weights after broadcasting
                for i in range(self.model_num):
                    self.models[i].load_state_dict(fedavg_models_ws[i])

    def check_and_save(self):
        """
        To save the results and save model after each evaluation, and check \
        whether to early stop.
        """
        # early stopping
        if "Results_weighted_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_weighted_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_weighted_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        elif "Results_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        else:
            should_stop = False

        if should_stop:
            self._monitor.global_converged()
            self.comm_manager.send(
                Message(
                    msg_type="converged",
                    sender=self.ID,
                    receiver=list(self.comm_manager.neighbors.keys()),
                    timestamp=self.cur_timestamp,
                    state=self.state,
                ))
            self.state = self.total_round_num + 1

        if should_stop or self.state >= self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'merging results.')
            # last round or early stopped
            self.save_best_results()
            if not self._cfg.federate.make_global_eval:
                self.save_client_eval_results()
            if self.eval_swa:
                self.terminate(msg_type='finish')
            else:
                # Run one more evaluation pass with the SWA-averaged model(s)
                self.eval_swa = True
                logger.info('Server: Evaluation with FedSWA')
                self.eval()

        # Clean the clients evaluation msg buffer
        if not self._cfg.federate.make_global_eval:
            round = max(self.msg_buffer['eval'].keys())
            self.msg_buffer['eval'][round].clear()

        if self.state == self.total_round_num:
            # break out the loop for distributed mode
            self.state += 1

    def save_best_results(self):
        """
        To save the best evaluation results.
        """
        if self._cfg.federate.save_to != '':
            self.aggregator.save_model(self._cfg.federate.save_to, self.state)
        formatted_best_res = self._monitor.format_eval_res(
            results=self.best_results,
            rnd="Final",
            role='Server SWA#' if self.eval_swa else 'Server #',
            forms=["raw"],
            return_raw=True)
        logger.info(formatted_best_res)
        self._monitor.save_formatted_results(formatted_best_res)

    # Bind the wrapped methods to the server instance
    setattr(server, 'eval_swa', False)
    server.check_and_move_on = types.MethodType(check_and_move_on, server)
    server.eval = types.MethodType(eval, server)
    server.check_and_save = types.MethodType(check_and_save, server)
    server.save_best_results = types.MethodType(save_best_results, server)

    return server