import logging

from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.workers import Client
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor

logger = logging.getLogger(__name__)


class ATCClient(Client):
    """Client for the hetero-task (ATC) NLP setting.

    Extends the base ``Client`` with two training protocols, chosen by
    ``config.model.use_contrastive_loss``:

    * without contrastive loss, every ``model_para`` message triggers one
      local training run followed by a full upload of sample size, model
      parameters and gradients;
    * with contrastive loss, client and server exchange a
      ``ContrastiveMonitor`` whose ``stat`` field drives which step of the
      exchange runs (see ``callback_funcs_for_model_para``).
    """

    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        # NOTE(review): any extra positional ``*args`` are unpacked before the
        # keyword arguments below, so they would bind to the parent's leading
        # positional parameters (which are also passed by keyword here) and
        # likely raise ``TypeError``. Pass extra options by keyword.
        super().__init__(
            ID=ID,
            server_id=server_id,
            state=state,
            config=config,
            data=data,
            model=model,
            device=device,
            strategy=strategy,
            *args,
            **kwargs,
        )
        self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
        # presumably registers this client's ID with the trainer's internal
        # statistics — confirm against the trainer implementation
        self.trainer.update_stat(self.ID)

    def _role_name(self):
        """Return the role label used in this client's log/eval output."""
        return 'Client #{}'.format(self.ID)

    def _log_round_results(self, results):
        """Log the raw training ``results`` for the current round."""
        logger.info(
            self._monitor.format_eval_res(results,
                                          rnd=self.state + 1,
                                          role=self._role_name(),
                                          return_raw=True))

    def _reply_model_para(self, receiver, content):
        """Send a ``model_para`` message with ``content`` to ``receiver``."""
        self.comm_manager.send(
            Message(msg_type='model_para',
                    sender=self.ID,
                    receiver=[receiver],
                    state=self.state,
                    content=content))

    def _copy_contrast_monitor(self, raw_monitor):
        """Build a fresh ``ContrastiveMonitor`` populated from ``raw_monitor``.

        Each attribute present on a newly constructed monitor is copied from
        ``raw_monitor`` through the matching ``update_<attr>`` method.

        Args:
            raw_monitor: monitor-like object received from the server; it
                must expose every attribute a default monitor has.

        Returns:
            ContrastiveMonitor: the populated copy.
        """
        monitor = ContrastiveMonitor()
        # Snapshot the attribute names first: the update_* setters run while
        # we loop, and iterating the live ``__dict__`` would raise
        # RuntimeError if any of them added a new attribute.
        for var in list(vars(monitor)):
            getattr(monitor, 'update_{}'.format(var))(getattr(raw_monitor,
                                                              var))
        return monitor

    def callback_funcs_for_model_para(self, message: Message):
        """Handle a ``model_para`` message: train locally and reply.

        Without contrastive loss, the client loads ``content['model_para']``
        and ``content['task']``, trains, and replies with sample size, model
        parameters and gradients.

        With contrastive loss, the received ``content['contrast_monitor']``
        is copied. The model parameters are loaded only when its ``stat`` is
        1. After training, the returned monitor's ``stat`` decides the reply:

        * 2 sends back only the monitor;
        * 3 logs the results and sends the full update plus the monitor;
        * any other value sends nothing and logs a warning.

        Args:
            message: incoming message; ``state`` becomes this client's round.
        """
        sender, content = message.sender, message.content
        self.state = message.state

        if not self.use_contrastive_loss:
            self.trainer.update(content['model_para'])
            self.trainer.update_pretrain_task(content['task'])
            sample_size, model_para_all, model_grads, results = \
                self.trainer.train()
            self._log_round_results(results)
            self._reply_model_para(
                sender, {
                    'sample_size': sample_size,
                    'model_para': model_para_all,
                    'model_grads': model_grads,
                })
            return

        last_contrast_monitor = self._copy_contrast_monitor(
            content['contrast_monitor'])
        if last_contrast_monitor.stat == 1:
            self.trainer.update(content['model_para'])
        self.trainer.update_contrast_monitor(last_contrast_monitor)
        sample_size, model_para_all, model_grads, contrast_monitor, \
            results = self.trainer.train()

        if contrast_monitor.stat == 2:
            self._reply_model_para(sender,
                                   {'contrast_monitor': contrast_monitor})
        elif contrast_monitor.stat == 3:
            self._log_round_results(results)
            self._reply_model_para(
                sender, {
                    'sample_size': sample_size,
                    'model_para': model_para_all,
                    'model_grads': model_grads,
                    'contrast_monitor': contrast_monitor,
                })
        else:
            # Previously this case was silently dropped, leaving the server
            # waiting on a reply that never comes.
            logger.warning(
                '%s: unexpected contrast monitor stat %r after training; '
                'no reply sent.', self._role_name(), contrast_monitor.stat)

    def callback_funcs_for_evaluate(self, message: Message):
        """Evaluate the local model and send the metrics to the sender.

        If the message carries content, ``content['model_para']`` and
        ``content['task']`` are loaded first.

        Once the early stopper has fired, the client skips evaluation. It
        reports the first entry of ``self.best_results`` instead.

        Otherwise it does the following:

        * optionally fine-tunes;
        * evaluates every split in ``cfg.eval.split``;
        * merges the raw results into ``self.history_results``.

        Args:
            message: incoming message; ``state`` becomes this client's round.
        """
        sender = message.sender
        self.state = message.state
        if message.content is not None:
            self.trainer.update(message.content['model_para'])
            self.trainer.update_pretrain_task(message.content['task'])

        if self.early_stopper.early_stopped:
            metrics = list(self.best_results.values())[0]
        else:
            metrics = {}
            if self._cfg.finetune.before_eval:
                self.trainer.finetune()
            for split in self._cfg.eval.split:
                eval_metrics = self.trainer.evaluate(
                    target_data_split_name=split)
                if self._cfg.federate.mode == 'distributed':
                    logger.info(
                        self._monitor.format_eval_res(
                            eval_metrics,
                            rnd=self.state + 1,
                            role=self._role_name()))
                metrics.update(eval_metrics)

            formatted_eval_res = self._monitor.format_eval_res(
                metrics,
                rnd=self.state + 1,
                role=self._role_name(),
                forms='raw',
                return_raw=True)
            self.history_results = merge_dict_of_results(
                self.history_results, formatted_eval_res['Results_raw'])

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=metrics))