112 lines
3.8 KiB
Python
112 lines
3.8 KiB
Python
import torch
|
|
from federatedscope.register import register_trainer
|
|
from federatedscope.core.trainers import GeneralTorchTrainer
|
|
from federatedscope.core.trainers.context import CtxVar
|
|
from federatedscope.core.trainers.enums import LIFECYCLE, MODE
|
|
from federatedscope.core.trainers.utils import move_to
|
|
|
|
|
|
class CLTrainer(GeneralTorchTrainer):
    """Trainer for a two-view contrastive model.

    The wrapped model is called as ``model(x1, x2)`` on two augmented views
    and returns a pair of embeddings ``(z1, z2)``. The per-batch loss is
    ``ctx.criterion(z1, z2)``. During training the most recent batch of
    augmented views is cached. :meth:`get_train_pred_embedding` can then
    recompute its embeddings, and :meth:`train_with_global_loss` can take an
    extra optimizer step on a loss supplied by the caller.
    """

    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super(CLTrainer, self).__init__(model, data, device, config,
                                        only_for_eval, monitor)
        # Placeholders until the first training batch is cached by
        # _hook_on_batch_forward.
        self.batches_aug_data_1, self.batches_aug_data_2 = torch.empty(
            1), torch.empty(1)
        # Embeddings produced by the latest get_train_pred_embedding() call.
        self.z1, self.z2 = torch.empty(1), torch.empty(1)
        # Mirrors ctx.num_samples; refreshed after every training batch.
        self.num_samples = 0
        # Multipliers for the per-batch (local) loss and for the loss passed
        # to train_with_global_loss(), respectively.
        self.local_loss_ratio = 1
        self.global_loss_ratio = 5

    def get_train_pred_embedding(self):
        """Recompute embeddings for the cached training batch.

        The model is moved to ``ctx.device`` for the forward pass and back to
        the CPU afterwards. No ``torch.no_grad()`` is used, so the returned
        embeddings keep their autograd graph. The cached views are consumed
        (reset to empty lists), so each cached batch can be embedded once.

        NOTE: before the first training batch, the ``torch.empty(1)``
        placeholders set in ``__init__`` are what gets fed to the model.

        Returns:
            list: ``[z1, z2]``, the model outputs for the two cached views.
            They are also stored on ``self.z1`` / ``self.z2``.

        Raises:
            RuntimeError: if the cached batch was already consumed by an
                earlier call and no new training batch has been cached.
        """
        if not (torch.is_tensor(self.batches_aug_data_1)
                and torch.is_tensor(self.batches_aug_data_2)):
            # The views were reset to [] by a previous call; fail loudly
            # instead of with an AttributeError on `[].to(...)`.
            raise RuntimeError(
                'No cached training batch: get_train_pred_embedding() was '
                'already called since the last training batch.')
        model = self.ctx.model.to(self.ctx.device)
        x1, x2 = self.batches_aug_data_1.to(
            self.ctx.device), self.batches_aug_data_2.to(self.ctx.device)
        z1, z2 = model(x1, x2)
        self.batches_aug_data_1, self.batches_aug_data_2 = [], []
        self.z1, self.z2 = z1, z2
        self.ctx.model.to(torch.device('cpu'))

        return [self.z1, self.z2]

    def _hook_on_batch_forward(self, ctx):
        """Run the two-view forward pass and record batch-level results.

        Sets ``ctx.y_true`` (labels), ``ctx.y_prob`` (the ``(z1, z2)`` pair),
        ``ctx.loss_batch`` and ``ctx.batch_size``, all with BATCH lifecycle.

        NOTE(review): assumes ``ctx.data_batch`` is ``(views, label)`` where
        ``views[0]`` / ``views[1]`` are the two augmented inputs. Confirm
        against the data loader.
        """
        x, label = [move_to(_, ctx.device) for _ in ctx.data_batch]
        x1, x2 = x[0], x[1]
        if ctx.cur_mode in [MODE.TRAIN]:
            # Keep the latest training views for get_train_pred_embedding().
            self.batches_aug_data_1 = x1
            self.batches_aug_data_2 = x2
        z1, z2 = ctx.model(x1, x2)
        # A 0-d label tensor has no len(); promote it to shape (1,).
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar((z1, z2), LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(z1, z2), LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)

    def _hook_on_batch_backward(self, ctx):
        """Scale ``ctx.loss_task`` by ``local_loss_ratio`` and take a step.

        ``ctx.loss_task`` is produced outside this class. Gradients are
        clipped to ``ctx.grad_clip`` when that value is positive, and the
        scheduler (if any) is stepped after the optimizer.
        """
        ctx.optimizer.zero_grad()
        ctx.loss_task = ctx.loss_task * self.local_loss_ratio
        ctx.loss_task.backward()
        if ctx.grad_clip > 0:
            torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                           ctx.grad_clip)

        ctx.optimizer.step()
        if ctx.scheduler is not None:
            ctx.scheduler.step()

    def _hook_on_batch_end(self, ctx):
        """Accumulate per-batch statistics and cache outputs for evaluation.

        Only the first embedding (``z1``) is cached into ``ctx.ys_prob``.
        """
        # update statistics
        ctx.num_samples += ctx.batch_size
        if ctx.cur_mode in [MODE.TRAIN]:
            self.num_samples = ctx.num_samples
        ctx.loss_batch_total += ctx.loss_batch.item() * ctx.batch_size
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))
        # cache label for evaluate
        ctx.ys_true.append(ctx.y_true.detach().cpu().numpy())
        ctx.ys_prob.append(ctx.y_prob[0].detach().cpu().numpy())

    def train_with_global_loss(self, loss):
        """Take one optimizer step on an externally computed loss.

        The loss is scaled by ``global_loss_ratio`` before back-propagation.

        Args:
            loss: scalar tensor (``backward()`` is called without a gradient
                argument). It is presumably built from the embeddings
                returned by ``get_train_pred_embedding`` so that its graph
                reaches the model parameters. Confirm with the caller.

        Returns:
            The model's ``state_dict()`` after the update.
        """
        self.ctx.model = self.ctx.model.to(self.ctx.device)

        # Clear gradients left over from the last local batch. Otherwise
        # backward() accumulates onto them and the step below re-applies a
        # local gradient that _hook_on_batch_backward already stepped on.
        self.ctx.optimizer.zero_grad()
        loss = loss.requires_grad_() * self.global_loss_ratio
        loss.backward()

        self.ctx.optimizer.step()

        return self.ctx.model.state_dict()
|
|
|
|
|
|
class LPTrainer(GeneralTorchTrainer):
    """GeneralTorchTrainer that optionally loads a saved model on creation.

    If ``config.federate.restore_from`` is anything other than the empty
    string, ``load_model`` is called with it right after the base trainer
    has been initialised.
    """

    def __init__(self,
                 model,
                 data,
                 device,
                 config,
                 only_for_eval=False,
                 monitor=None):
        super().__init__(model, data, device, config, only_for_eval, monitor)

        checkpoint_path = config.federate.restore_from
        # An empty string means "nothing to restore".
        if checkpoint_path == '':
            return
        self.load_model(checkpoint_path)
|
|
|
|
|
|
def call_cl_trainer(trainer_type):
    """Resolve a trainer type name to the trainer class defined here.

    Args:
        trainer_type: the trainer type identifier to look up.

    Returns:
        ``CLTrainer`` for ``'cltrainer'``, ``LPTrainer`` for
        ``'lptrainer'``, and ``None`` for any other value.
    """
    if trainer_type == 'cltrainer':
        return CLTrainer
    if trainer_type == 'lptrainer':
        return LPTrainer
    return None
|
|
|
|
|
|
# Register call_cl_trainer under the key 'cltrainer'.
# NOTE(review): call_cl_trainer also resolves 'lptrainer' although it is
# registered only under 'cltrainer' -- presumably the trainer registry calls
# every registered builder with the configured trainer type; confirm.
register_trainer('cltrainer', call_cl_trainer)
|