32 lines
904 B
Python
32 lines
904 B
Python
from federatedscope.register import register_trainer
|
|
from federatedscope.core.trainers import GeneralTorchTrainer
|
|
from federatedscope.core.trainers.utils import move_to
|
|
|
|
|
|
class NLPTrainer(GeneralTorchTrainer):
    """
    ``NLPTrainer`` is used for text data.

    It inherits every hook of ``GeneralTorchTrainer`` and overrides only the
    per-batch forward pass, so that dict-shaped inputs can be passed to the
    model as keyword arguments.
    """

    def _hook_on_batch_forward(self, ctx):
        """Run the model on the current batch and record loss and outputs.

        Args:
            ctx: the trainer context. This hook reads ``ctx.data_batch``
                (must unpack into exactly two items: inputs and labels),
                ``ctx.device``, ``ctx.model`` and ``ctx.criterion``. It
                writes ``ctx.loss_batch``, ``ctx.y_true``, ``ctx.y_prob``
                and ``ctx.batch_size``.
        """
        x, label = [move_to(item, ctx.device) for item in ctx.data_batch]

        if isinstance(x, dict):
            # Keyword-style inputs (presumably tokenizer output such as
            # input_ids / attention_mask -- TODO confirm). The model output
            # is indexed with [0], i.e. its first element is taken as the
            # prediction.
            pred = ctx.model(**x)[0]
        else:
            pred = ctx.model(x)

        # Promote a 0-dim label to shape (1,) so the criterion sees a batch
        # dimension and len(label) below is valid (len() of a 0-d tensor
        # raises).
        if label.dim() == 0:
            label = label.unsqueeze(0)

        ctx.loss_batch = ctx.criterion(pred, label)
        ctx.y_true = label
        # Raw model output; no softmax/normalisation is applied here.
        ctx.y_prob = pred
        ctx.batch_size = len(label)
|
|
|
|
|
|
def call_nlp_trainer(trainer_type):
    """Map a trainer type name to the ``NLPTrainer`` class.

    Args:
        trainer_type: the configured trainer type name.

    Returns:
        ``NLPTrainer`` when ``trainer_type == 'nlptrainer'``, otherwise
        ``None`` (presumably so the trainer registry can try other builders
        -- confirm against ``register_trainer``'s consumer).
    """
    if trainer_type == 'nlptrainer':
        return NLPTrainer
    # Unmatched types return None explicitly rather than reaching a return
    # of a local that was never bound (which would raise UnboundLocalError).
    return None
|
|
|
|
|
|
# Register ``call_nlp_trainer`` under the name 'nlptrainer'; this runs at
# import time as a module-level side effect.
register_trainer('nlptrainer', call_nlp_trainer)
|