import logging
from federatedscope.core.message import Message
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.core.workers import Client
from federatedscope.nlp.hetero_tasks.trainer.utils import ContrastiveMonitor
logger = logging.getLogger(__name__)
class ATCClient(Client):
    """Client worker for the ATC hetero-task NLP setting.

    Extends the base ``Client`` with (a) server-assigned pretext tasks for
    pre-training and (b) an optional contrastive-learning protocol in which
    a ``ContrastiveMonitor`` object is exchanged with the server over
    several stages (``stat`` 1 -> 2 -> 3).
    """
    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        super().__init__(ID=ID,
                         server_id=server_id,
                         state=state,
                         config=config,
                         data=data,
                         model=model,
                         device=device,
                         strategy=strategy,
                         *args,
                         **kwargs)

        # Cached flag: whether the contrastive branch of the protocol is on.
        self.use_contrastive_loss = self._cfg.model.use_contrastive_loss
        # Tell the trainer which client it serves (used for the trainer's
        # own bookkeeping/logging).
        self.trainer.update_stat(self.ID)

    def _copy_contrast_monitor(self, raw_monitor):
        """Return a fresh ``ContrastiveMonitor`` populated from
        ``raw_monitor``.

        Each attribute ``var`` of a new monitor is set by calling the
        matching ``update_<var>`` setter with ``raw_monitor``'s value, so
        the incoming (possibly shared) monitor object is never aliased.
        """
        monitor = ContrastiveMonitor()
        for var in vars(monitor):
            getattr(monitor,
                    'update_{}'.format(var))(getattr(raw_monitor, var))
        return monitor

    def callback_funcs_for_model_para(self, message: Message):
        """Handle a ``model_para`` message: train locally and reply.

        Without contrastive loss: load the broadcast model and pretext
        task, train, log the results, and send parameters/gradients back.

        With contrastive loss: the reply depends on the stage reached by
        the ``ContrastiveMonitor`` after training — at ``stat == 2`` only
        the monitor is sent back; at ``stat == 3`` the trained model plus
        the monitor is sent back.
        """
        # NOTE: named ``rnd`` (not ``round``) to avoid shadowing the builtin.
        rnd, sender, content = message.state, message.sender, message.content
        self.state = rnd

        if not self.use_contrastive_loss:
            self.trainer.update(content['model_para'])
            self.trainer.update_pretrain_task(content['task'])

            sample_size, model_para_all, model_grads, results = \
                self.trainer.train()

            logger.info(
                self._monitor.format_eval_res(results,
                                              rnd=self.state + 1,
                                              role='Client #{}'.format(
                                                  self.ID),
                                              return_raw=True))
            self.comm_manager.send(
                Message(msg_type='model_para',
                        sender=self.ID,
                        receiver=[sender],
                        state=self.state,
                        content={
                            'sample_size': sample_size,
                            'model_para': model_para_all,
                            'model_grads': model_grads,
                        }))
        else:
            # Work on a private copy so local updates never mutate the
            # monitor object carried inside the incoming message.
            last_contrast_monitor = self._copy_contrast_monitor(
                content['contrast_monitor'])
            # Stage 1: the server has just broadcast fresh model weights.
            if last_contrast_monitor.stat == 1:
                self.trainer.update(content['model_para'])
            self.trainer.update_contrast_monitor(last_contrast_monitor)

            sample_size, model_para_all, model_grads, contrast_monitor, \
                results = self.trainer.train()

            if contrast_monitor.stat == 2:
                # Intermediate stage: exchange only the monitor.
                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content={'contrast_monitor': contrast_monitor}))

            elif contrast_monitor.stat == 3:
                # Final stage: log results and return the trained model
                # together with the monitor.
                logger.info(
                    self._monitor.format_eval_res(results,
                                                  rnd=self.state + 1,
                                                  role='Client #{}'.format(
                                                      self.ID),
                                                  return_raw=True))
                self.comm_manager.send(
                    Message(msg_type='model_para',
                            sender=self.ID,
                            receiver=[sender],
                            state=self.state,
                            content={
                                'sample_size': sample_size,
                                'model_para': model_para_all,
                                'model_grads': model_grads,
                                'contrast_monitor': contrast_monitor
                            }))

    def callback_funcs_for_evaluate(self, message: Message):
        """Handle an ``evaluate`` request.

        Optionally loads the broadcast model and pretext task, evaluates on
        every configured split (after an optional fine-tune pass), merges
        the raw results into ``self.history_results``, and sends the
        metrics back to the requester.
        """
        sender = message.sender
        self.state = message.state
        if message.content is not None:
            self.trainer.update(message.content['model_para'])
            self.trainer.update_pretrain_task(message.content['task'])
        if self.early_stopper.early_stopped:
            # Start from the cached best results.  Copy the dict: the
            # ``metrics.update(...)`` below must not silently mutate the
            # entry stored in ``self.best_results`` (aliasing bug in the
            # original code).
            metrics = dict(list(self.best_results.values())[0])
        else:
            metrics = {}
        if self._cfg.finetune.before_eval:
            self.trainer.finetune()
        for split in self._cfg.eval.split:
            eval_metrics = self.trainer.evaluate(
                target_data_split_name=split)

            if self._cfg.federate.mode == 'distributed':
                logger.info(
                    self._monitor.format_eval_res(eval_metrics,
                                                  rnd=self.state + 1,
                                                  role='Client #{}'.format(
                                                      self.ID)))

            metrics.update(**eval_metrics)

        formatted_eval_res = self._monitor.format_eval_res(
            metrics,
            rnd=self.state + 1,
            role='Client #{}'.format(self.ID),
            forms='raw',
            return_raw=True)
        self.history_results = merge_dict_of_results(
            self.history_results, formatted_eval_res['Results_raw'])

        self.comm_manager.send(
            Message(msg_type='metrics',
                    sender=self.ID,
                    receiver=[sender],
                    state=self.state,
                    content=metrics))