FS-TFP/federatedscope/nlp/trainer/trainer.py

32 lines
904 B
Python

from federatedscope.register import register_trainer
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.core.trainers.utils import move_to
class NLPTrainer(GeneralTorchTrainer):
    """
    Trainer for text data, built on top of ``GeneralTorchTrainer``.

    Only the batch-forward hook is overridden in this class.
    """
    def _hook_on_batch_forward(self, ctx):
        """
        Run the model on ``ctx.data_batch`` and record the results on ``ctx``.

        The batch must unpack into exactly two items (inputs and labels);
        each is passed through ``move_to`` together with ``ctx.device``.
        Dict inputs are expanded into keyword arguments and element ``[0]``
        of the model output is used as the prediction; any other input is
        passed positionally and the model output is used as is.

        Args:
            ctx: Context object providing ``data_batch``, ``device``,
                ``model`` and ``criterion``. This hook sets ``loss_batch``,
                ``y_true``, ``y_prob`` and ``batch_size`` on it.
        """
        inputs, targets = tuple(
            move_to(item, ctx.device) for item in ctx.data_batch)

        if not isinstance(inputs, dict):
            outputs = ctx.model(inputs)
        else:
            outputs = ctx.model(**inputs)[0]

        # A label whose size is empty gets a leading axis added; presumably
        # so the criterion and ``len`` below see a one-element batch --
        # TODO confirm against the data loader.
        if len(targets.size()) == 0:
            targets = targets.unsqueeze(0)

        ctx.loss_batch = ctx.criterion(outputs, targets)
        ctx.y_true = targets
        ctx.y_prob = outputs
        ctx.batch_size = len(targets)
def call_nlp_trainer(trainer_type):
    """
    Resolve a trainer type name to the trainer class defined in this module.

    Args:
        trainer_type: Name of the requested trainer.

    Returns:
        ``NLPTrainer`` when ``trainer_type`` equals ``'nlptrainer'``
        (case-sensitive); otherwise ``None``, so the caller can tell that
        this builder does not handle the requested type.
    """
    if trainer_type == 'nlptrainer':
        return NLPTrainer
    # Explicit no-match result: never touch an unbound local, and keep
    # every return statement returning a value.
    return None
# Module-level call, so it runs when this module is imported: passes the
# builder function to ``register_trainer`` under the key 'nlptrainer'.
register_trainer('nlptrainer', call_nlp_trainer)