FS-TFP/federatedscope/core/workers/wrapper/fedswa.py

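"""FedSWA server wrapper.

Monkey-patches a FederatedScope ``Server`` so that, starting from round
``fedswa.start_rnd``, the aggregated global model is folded into a running
average every ``fedswa.freq`` rounds (server-side Stochastic Weight
Averaging). Once training finishes, the averaged model is evaluated once
more before the server terminates.
"""
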
import copy
import types
import logging

from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results

logger = logging.getLogger(__name__)
def wrap_swa_server(server):
    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        if min_received_num is None:
            if self._cfg.asyn.use:
                min_received_num = self._cfg.asyn.min_received_num
            else:
                min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result and \
                self._cfg.federate.mode.lower() == "standalone":
            # In the evaluation stage of standalone simulation mode, we
            # assume strong synchronization, i.e., we wait for responses
            # from all clients
            min_received_num = len(self.comm_manager.get_neighbors().keys())

        # Record whether we move on to a new training round or finish the
        # evaluation
        move_on_flag = True
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:
                # Received enough feedback during the training process
                aggregated_num = self._perform_federated_aggregation()
                self.state += 1
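
                # FedSWA schedule: take the first model snapshot at round
                # ``fedswa.start_rnd``, then fold a new snapshot into the
                # running average every ``fedswa.freq`` rounds.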
                if self.state == self._cfg.fedswa.start_rnd:
                    # Snapshot the current global weights; deepcopy them
                    # because state_dict() returns references to the live
                    # tensors, which later load_state_dict() calls would
                    # overwrite in place
                    self.swa_models_ws = [
                        copy.deepcopy(model.state_dict())
                        for model in self.models
                    ]
                    self.swa_rnd = 1
                elif self.state > self._cfg.fedswa.start_rnd and \
                        (self.state - self._cfg.fedswa.start_rnd) % \
                        self._cfg.fedswa.freq == 0:
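                    # Fold the new global model into the running average:
                    # avg_k = (avg_{k-1} * k + w_k) / (k + 1), so every
                    # cached snapshot contributes with equal weight.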
                    logger.info(f'FedSWA: folding model snapshot '
                                f'{self.swa_rnd + 1} into the cache.')
                    for model_w, new_model in zip(self.swa_models_ws,
                                                  self.models):
                        new_model_w = new_model.state_dict()
                        for key in model_w.keys():
                            model_w[key] = \
                                (model_w[key] * self.swa_rnd +
                                 new_model_w[key]) / (self.swa_rnd + 1)
                    self.swa_rnd += 1
                if self.state % self._cfg.eval.freq == 0 and \
                        self.state != self.total_round_num:
                    # Evaluate
                    logger.info(f'Server: Starting evaluation at the end '
                                f'of round {self.state - 1}.')
                    self.eval()

                if self.state < self.total_round_num:
                    # Move on to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()
                    self.msg_buffer['train'][self.state] = dict()
                    self.staled_msg_buffer.clear()
                    # Start a new training round
                    self._start_new_training_round(aggregated_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:
                # Received enough feedback during the evaluation process
                self._merge_and_format_eval_results()
        else:
            move_on_flag = False

        return move_on_flag
    def eval(self):
        if self._cfg.federate.make_global_eval:
            for i in range(self.model_num):
                trainer = self.trainers[i]
                if self.eval_swa:
                    # Swap in the SWA model; deepcopy the FedAvg weights
                    # so that they survive load_state_dict() and can be
                    # restored afterwards
                    fedavg_model_w = copy.deepcopy(
                        self.models[i].state_dict())
                    self.models[i].load_state_dict(self.swa_models_ws[i])
                # Perform the evaluation on the server
                metrics = {}
                for split in self._cfg.eval.split:
                    eval_metrics = trainer.evaluate(
                        target_data_split_name=split)
                    metrics.update(**eval_metrics)
                formatted_eval_res = self._monitor.format_eval_res(
                    metrics,
                    rnd=self.state,
                    role='Server SWA#' if self.eval_swa else 'Server #',
                    forms=self._cfg.eval.report,
                    return_raw=self._cfg.federate.make_global_eval)
                if self.eval_swa:
                    # Restore the FedAvg weights and record the SWA
                    # results as the final ones to report
                    self.models[i].load_state_dict(fedavg_model_w)
                    self.best_results = formatted_eval_res['Results_raw']
                else:
                    self._monitor.update_best_result(
                        self.best_results,
                        formatted_eval_res['Results_raw'],
                        results_type="server_global_eval")
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self._monitor.save_formatted_results(formatted_eval_res)
                logger.info(formatted_eval_res)
            self.check_and_save()
        else:
            if self.eval_swa:
                # Swap in the SWA models; deepcopy every model's FedAvg
                # weights so that each of them can be restored after the
                # evaluation
                fedavg_models_ws = [
                    copy.deepcopy(self.models[i].state_dict())
                    for i in range(self.model_num)
                ]
                for i in range(self.model_num):
                    self.models[i].load_state_dict(self.swa_models_ws[i])
            # Perform the evaluation on the clients
            self.broadcast_model_para(msg_type='evaluate',
                                      filter_unseen_clients=False)
            if self.eval_swa:
                # Restore the FedAvg weights
                for i in range(self.model_num):
                    self.models[i].load_state_dict(fedavg_models_ws[i])
    def check_and_save(self):
        """
        Save the results and the model after each evaluation, and check
        whether to early stop.
        """
        # Early stopping
        if "Results_weighted_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_weighted_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_weighted_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        elif "Results_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        else:
            should_stop = False
        if should_stop:
            self._monitor.global_converged()
            self.comm_manager.send(
                Message(
                    msg_type="converged",
                    sender=self.ID,
                    receiver=list(self.comm_manager.neighbors.keys()),
                    timestamp=self.cur_timestamp,
                    state=self.state,
                ))
            # Jump past the last round so that no new round is started
            self.state = self.total_round_num + 1

        if should_stop or self.state >= self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'to merge results.')
            # Last round or early stopped
            self.save_best_results()
            if not self._cfg.federate.make_global_eval:
                self.save_client_eval_results()
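            # Two-pass finish: on the first pass, switch to the SWA
            # weights and re-run the evaluation (eval_swa=True); on the
            # second pass, after that evaluation completes, terminate.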
            if self.eval_swa:
                self.terminate(msg_type='finish')
            else:
                self.eval_swa = True
                logger.info('Server: Evaluating with FedSWA.')
                self.eval()

        # Clean the clients' evaluation msg buffer
        if not self._cfg.federate.make_global_eval:
            rnd = max(self.msg_buffer['eval'].keys())
            self.msg_buffer['eval'][rnd].clear()

        if self.state == self.total_round_num:
            # Break out of the loop for distributed mode
            self.state += 1
    def save_best_results(self):
        """
        Save the best evaluation results.
        """
        if self._cfg.federate.save_to != '':
            self.aggregator.save_model(self._cfg.federate.save_to, self.state)
        formatted_best_res = self._monitor.format_eval_res(
            results=self.best_results,
            rnd="Final",
            role='Server SWA#' if self.eval_swa else 'Server #',
            forms=["raw"],
            return_raw=True)
        logger.info(formatted_best_res)
        self._monitor.save_formatted_results(formatted_best_res)
    # Bind the patched methods and the SWA state flag to the instance
    setattr(server, 'eval_swa', False)
    server.check_and_move_on = types.MethodType(check_and_move_on, server)
    server.eval = types.MethodType(eval, server)
    server.check_and_save = types.MethodType(check_and_save, server)
    server.save_best_results = types.MethodType(save_best_results, server)
    return server
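
# A minimal usage sketch (hypothetical setup; the wrapper only assumes the
# FederatedScope ``Server`` interface patched above):
#
#     from federatedscope.core.workers import Server
#
#     server = Server(...)              # built as usual, with
#                                       # cfg.fedswa.{start_rnd, freq} set
#     server = wrap_swa_server(server)  # patch in the FedSWA behaviour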