221 lines
9.1 KiB
Python
221 lines
9.1 KiB
Python
import types
|
|
import logging
|
|
|
|
from federatedscope.core.message import Message
|
|
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def wrap_swa_server(server):
    """Wrap a FederatedScope server with FedSWA behavior.

    Overrides ``check_and_move_on``, ``eval``, ``check_and_save`` and
    ``save_best_results`` on the given ``server`` instance so that, from
    round ``cfg.fedswa.start_rnd`` on, a running average of the global
    model weights is maintained (every ``cfg.fedswa.freq`` rounds) and an
    extra evaluation pass with the averaged weights is performed at the end.

    Args:
        server: the server instance to wrap (modified in place).

    Returns:
        The same ``server`` instance, with methods rebound.
    """
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        Check whether enough messages have been received; if so, perform
        aggregation, update the FedSWA running average of model weights,
        and move on to the next training round (or finish evaluation).

        Args:
            check_eval_result: if True, check the evaluation buffer
                instead of the training buffer.
            min_received_num: minimal number of received messages needed
                to move on; defaults to the asyn/sync config value.

        Returns:
            bool: whether we moved to a new round / finished evaluation.
        """
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and self._cfg.federate.mode.lower(
        ) == "standalone":
            # in evaluation stage and standalone simulation mode, we assume
            # strong synchronization that receives responses from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Receiving enough feedback in the training process
                aggregated_num = self._perform_federated_aggregation()

                # NOTE: self.state is advanced BEFORE the SWA bookkeeping
                # below; the caching conditions are written against the
                # incremented round number.
                self.state += 1

                # FedSWA cache model
                if self.state == self._cfg.fedswa.start_rnd:
                    # First SWA snapshot: cache one state_dict per model.
                    self.swa_models_ws = [
                        model.state_dict() for model in self.models
                    ]
                    self.swa_rnd = 1
                elif self.state > \
                        self._cfg.fedswa.start_rnd and \
                        (self.state - self._cfg.fedswa.start_rnd) % \
                        self._cfg.fedswa.freq == 0:
                    logger.info(f'FedSWA cache {self.swa_rnd} models.')
                    # 'model' is a cached state_dict from swa_models_ws;
                    # 'new_model' is rebound to the current model's
                    # state_dict before averaging.
                    for model, new_model in zip(self.swa_models_ws,
                                                self.models):
                        new_model = new_model.state_dict()
                        for key in model.keys():
                            # Running mean over snapshots:
                            # mean_{t+1} = (mean_t * t + x_{t+1}) / (t + 1)
                            model[key] = (model[key] * self.swa_rnd +
                                          new_model[key]) / (self.swa_rnd + 1)
                    self.swa_rnd += 1

                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:
                # Receiving enough feedback in the evaluation process
                self._merge_and_format_eval_results()
        else:
            move_on_flag = False

        return move_on_flag
|
|
|
|
def eval(self):
    """
    Evaluate the current (and, when ``self.eval_swa`` is set, the
    SWA-averaged) model.

    In ``make_global_eval`` mode the server evaluates locally with its own
    trainers; otherwise evaluation is delegated to the clients via an
    'evaluate' broadcast. When ``self.eval_swa`` is True, the FedAvg
    weights are cached, the SWA-averaged weights are swapped in for the
    evaluation, and the FedAvg weights are restored afterwards.
    """
    if self._cfg.federate.make_global_eval:
        for i in range(self.model_num):
            trainer = self.trainers[i]

            if self.eval_swa:
                # Use swa model: cache FedAvg weights, load SWA average.
                fedavg_model_w = self.models[i].state_dict()
                self.models[i].load_state_dict(self.swa_models_ws[i])

            # Perform evaluation in server
            metrics = {}
            for split in self._cfg.eval.split:
                eval_metrics = trainer.evaluate(
                    target_data_split_name=split)
                metrics.update(**eval_metrics)

            formatted_eval_res = self._monitor.format_eval_res(
                metrics,
                rnd=self.state,
                role='Server SWA#' if self.eval_swa else 'Server #',
                forms=self._cfg.eval.report,
                return_raw=self._cfg.federate.make_global_eval)

            if self.eval_swa:
                # Restore the cached FedAvg weights for this model and
                # report the SWA results as the best results directly.
                self.models[i].load_state_dict(fedavg_model_w)
                self.best_results = formatted_eval_res['Results_raw']
            else:
                self._monitor.update_best_result(
                    self.best_results,
                    formatted_eval_res['Results_raw'],
                    results_type="server_global_eval")
            self.history_results = merge_dict_of_results(
                self.history_results, formatted_eval_res)
            self._monitor.save_formatted_results(formatted_eval_res)
            logger.info(formatted_eval_res)
        self.check_and_save()
    else:
        if self.eval_swa:
            # Use swa models: cache each model's FedAvg weights before
            # loading the SWA averages. BUGFIX: previously a single
            # variable was overwritten per iteration, so with
            # model_num > 1 every model was "restored" from the LAST
            # model's cached weights; we now keep one cache per model.
            fedavg_model_ws = []
            for i in range(self.model_num):
                fedavg_model_ws.append(self.models[i].state_dict())
                self.models[i].load_state_dict(self.swa_models_ws[i])
        # Perform evaluation in clients
        self.broadcast_model_para(msg_type='evaluate',
                                  filter_unseen_clients=False)

        if self.eval_swa:
            # Restore each model's own cached FedAvg weights.
            for i in range(self.model_num):
                self.models[i].load_state_dict(fedavg_model_ws[i])
|
|
|
|
def check_and_save(self):
    """
    To save the results and save model after each evaluation, and check \
    whether to early stop.

    Additionally schedules the extra FedSWA evaluation pass: on the final
    round (or early stop), if the SWA pass has not run yet, ``eval_swa``
    is flipped on and ``self.eval()`` is invoked once more (which recurses
    back here exactly once before terminating).
    """
    # early stopping: track the configured key on weighted-avg results if
    # present, else on plain-avg results; otherwise never stop early.
    if "Results_weighted_avg" in self.history_results and \
            self._cfg.eval.best_res_update_round_wise_key in \
            self.history_results['Results_weighted_avg']:
        should_stop = self.early_stopper.track_and_check(
            self.history_results['Results_weighted_avg'][
                self._cfg.eval.best_res_update_round_wise_key])
    elif "Results_avg" in self.history_results and \
            self._cfg.eval.best_res_update_round_wise_key in \
            self.history_results['Results_avg']:
        should_stop = self.early_stopper.track_and_check(
            self.history_results['Results_avg'][
                self._cfg.eval.best_res_update_round_wise_key])
    else:
        should_stop = False

    if should_stop:
        self._monitor.global_converged()
        # Notify all neighbors (clients) that training has converged.
        self.comm_manager.send(
            Message(
                msg_type="converged",
                sender=self.ID,
                receiver=list(self.comm_manager.neighbors.keys()),
                timestamp=self.cur_timestamp,
                state=self.state,
            ))
        # Push state past the final round so the loop below terminates.
        self.state = self.total_round_num + 1

    if should_stop or self.state >= self.total_round_num:
        logger.info('Server: Final evaluation is finished! Starting '
                    'merging results.')
        # last round or early stopped
        self.save_best_results()
        if not self._cfg.federate.make_global_eval:
            self.save_client_eval_results()

        if self.eval_swa:
            # SWA evaluation pass already done: finish for good.
            self.terminate(msg_type='finish')
        else:
            # Run one extra evaluation with the SWA-averaged weights.
            self.eval_swa = True
            logger.info('Server: Evaluation with FedSWA')
            self.eval()

    # Clean the clients evaluation msg buffer
    if not self._cfg.federate.make_global_eval:
        # Renamed from 'round' to avoid shadowing the builtin round().
        eval_round = max(self.msg_buffer['eval'].keys())
        self.msg_buffer['eval'][eval_round].clear()

    if self.state == self.total_round_num:
        # break out the loop for distributed mode
        self.state += 1
|
|
|
|
def save_best_results(self):
    """Persist the best evaluation results (and optionally the model).

    Saves the aggregated model when ``cfg.federate.save_to`` is set, then
    formats ``self.best_results`` as a final raw report, logs it and hands
    it to the monitor for persistence.
    """
    save_path = self._cfg.federate.save_to
    if save_path != '':
        # Checkpoint the aggregated model alongside the results.
        self.aggregator.save_model(save_path, self.state)

    if self.eval_swa:
        reporter_role = 'Server SWA#'
    else:
        reporter_role = 'Server #'
    best_res_report = self._monitor.format_eval_res(results=self.best_results,
                                                    rnd="Final",
                                                    role=reporter_role,
                                                    forms=["raw"],
                                                    return_raw=True)
    logger.info(best_res_report)
    self._monitor.save_formatted_results(best_res_report)
|
|
|
|
# Bind method to instance
|
|
setattr(server, 'eval_swa', False)
|
|
server.check_and_move_on = types.MethodType(check_and_move_on, server)
|
|
server.eval = types.MethodType(eval, server)
|
|
server.check_and_save = types.MethodType(check_and_save, server)
|
|
server.save_best_results = types.MethodType(save_best_results, server)
|
|
|
|
return server
|