FS-TFP/federatedscope/nlp/hetero_tasks/worker/client.py

152 lines
5.8 KiB
Python

import logging
from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.workers import Client
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor
logger = logging.getLogger(__name__)
class ATCClient(Client):
def __init__(self,
ID=-1,
server_id=None,
state=-1,
config=None,
data=None,
model=None,
device='cpu',
strategy=None,
*args,
**kwargs):
super().__init__(
ID=ID,
server_id=server_id,
state=state,
config=config,
data=data,
model=model,
device=device,
strategy=strategy,
*args,
**kwargs,
)
self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
self.trainer.update_stat(self.ID)
def _copy_contrast_monitor(self, raw_monitor):
monitor = ContrastiveMonitor()
for var in vars(monitor):
getattr(monitor,
'update_{}'.format(var))(getattr(raw_monitor, var))
return monitor
def callback_funcs_for_model_para(self, message: Message):
round, sender, content = message.state, message.sender, message.content
self.state = round
if not self.use_contrastive_loss:
self.trainer.update(content['model_para'])
self.trainer.update_pretrain_task(content['task'])
sample_size, model_para_all, model_grads, results = \
self.trainer.train()
logger.info(
self._monitor.format_eval_res(results,
rnd=self.state + 1,
role='Client #{}'.format(
self.ID),
return_raw=True))
self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
receiver=[sender],
state=self.state,
content={
'sample_size': sample_size,
'model_para': model_para_all,
'model_grads': model_grads,
}))
else:
last_contrast_monitor = self._copy_contrast_monitor(
content['contrast_monitor'])
if last_contrast_monitor.stat == 1:
self.trainer.update(content['model_para'])
self.trainer.update_contrast_monitor(last_contrast_monitor)
sample_size, model_para_all, model_grads, contrast_monitor, \
results = self.trainer.train()
if contrast_monitor.stat == 2:
self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
receiver=[sender],
state=self.state,
content={'contrast_monitor': contrast_monitor}))
elif contrast_monitor.stat == 3:
logger.info(
self._monitor.format_eval_res(results,
rnd=self.state + 1,
role='Client #{}'.format(
self.ID),
return_raw=True))
self.comm_manager.send(
Message(msg_type='model_para',
sender=self.ID,
receiver=[sender],
state=self.state,
content={
'sample_size': sample_size,
'model_para': model_para_all,
'model_grads': model_grads,
'contrast_monitor': contrast_monitor
}))
def callback_funcs_for_evaluate(self, message: Message):
sender = message.sender
self.state = message.state
if message.content is not None:
self.trainer.update(message.content['model_para'])
self.trainer.update_pretrain_task(message.content['task'])
if self.early_stopper.early_stopped:
metrics = list(self.best_results.values())[0]
else:
metrics = {}
if self._cfg.finetune.before_eval:
self.trainer.finetune()
for split in self._cfg.eval.split:
eval_metrics = self.trainer.evaluate(
target_data_split_name=split)
if self._cfg.federate.mode == 'distributed':
logger.info(
self._monitor.format_eval_res(eval_metrics,
rnd=self.state + 1,
role='Client #{}'.format(
self.ID)))
metrics.update(**eval_metrics)
formatted_eval_res = self._monitor.format_eval_res(
metrics,
rnd=self.state + 1,
role='Client #{}'.format(self.ID),
forms='raw',
return_raw=True)
self.history_results = merge_dict_of_results(
self.history_results, formatted_eval_res['Results_raw'])
self.comm_manager.send(
Message(msg_type='metrics',
sender=self.ID,
receiver=[sender],
state=self.state,
content=metrics))