FS-TFP/federatedscope/gfl/flitplus/trainer.py

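"""Trainers for the loss re-weighting federated algorithms in the gfl
package: FLIT, FLIT+, FedFocal and FedVAT.

Each trainer overrides ``_hook_on_batch_forward`` to scale the per-sample
loss with a focal-style factor normalised by an exponential moving average
(``ctx.weight_denomaitor``); FLIT and FLIT+ additionally compare the local
model against a frozen copy of the global model, and FLIT+ and FedVAT add a
virtual adversarial training (VAT) regulariser.
"""
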
import torch
from copy import deepcopy
from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.gfl.loss.vat import VATLoss
from federatedscope.core.trainers import GeneralTorchTrainer


class FLITTrainer(GeneralTorchTrainer):
    """Trainer for FLIT, which re-weights each sample's loss according to
    how much worse the local model performs on it than the shared global
    model (kept as a frozen copy in ``ctx.global_model``).
    """
    def register_default_hooks_train(self):
        super(FLITTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=record_initialization_global,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_global,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FLITTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=record_initialization_global,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_global,
                                   trigger='on_fit_end',
                                   insert_pos=-1)
    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)
        ctx.global_model.to(ctx.device)
        predG = ctx.global_model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FLIT trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        # Losses of the local and the (frozen) global model on the same batch.
        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Up-weight samples on which the local model is worse than the global
        # model; the global loss is detached so no gradient flows through the
        # global copy.
        weightloss = lossLocalLabel + \
            torch.relu(lossLocalLabel - lossGlobalLabel.detach())

        # Normalise the weights by an exponential moving average over batches.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ctx.weight_denomaitor = \
                self.cfg.flitplus.factor_ema * ctx.weight_denomaitor + \
                (1 - self.cfg.flitplus.factor_ema) * \
                weightloss.mean(dim=0, keepdim=True).detach()

        # Focal-style re-weighting: (1 - exp(-w / ema))^tmpFed scales the
        # supervised loss of each sample.
        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FLITPlusTrainer(FLITTrainer):
    """FLIT+ trainer, which extends FLIT with a virtual adversarial training
    (VAT) regulariser computed on both the local and the global model.
    """
    def _hook_on_batch_forward(self, ctx):
        # LDS (the VAT term) should be calculated before the forward pass
        # for cross entropy.
        batch = ctx.data_batch.to(ctx.device)
        ctx.global_model.to(ctx.device)

        if ctx.cur_mode == 'test':
            lossLocalVAT, lossGlobalVAT = torch.tensor(0.), torch.tensor(0.)
        else:
            vat_loss = VATLoss()  # default xi and eps
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))
            lossGlobalVAT = vat_loss(deepcopy(ctx.global_model), batch,
                                     deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        predG = ctx.global_model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FLIT+ trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Combine the supervised and the VAT discrepancies between the local
        # and the global model into a single per-sample weight.
        weightloss_loss = lossLocalLabel + \
            torch.relu(lossLocalLabel - lossGlobalLabel.detach())
        weightloss_vat = lossLocalVAT + \
            torch.relu(lossLocalVAT - lossGlobalVAT.detach())
        weightloss = weightloss_loss + \
            self.cfg.flitplus.lambdavat * weightloss_vat

        # Normalise the weights by an exponential moving average over batches.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ctx.weight_denomaitor = \
                self.cfg.flitplus.factor_ema * ctx.weight_denomaitor + \
                (1 - self.cfg.flitplus.factor_ema) * \
                weightloss.mean(dim=0, keepdim=True).detach()

        # Re-weighted objective: supervised loss plus the VAT regulariser.
        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * (
                    lossLocalLabel +
                    self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedFocalTrainer(GeneralTorchTrainer):
    """Trainer that applies the focal-style re-weighting to the local loss
    only, without comparing against a global model.
    """
    def register_default_hooks_train(self):
        super(FedFocalTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FedFocalTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FedFocal trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        # The weight is driven by the local loss alone.
        lossLocalLabel = ctx.criterion(pred, label)
        weightloss = lossLocalLabel

        # Normalise the weights by an exponential moving average over batches.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ctx.weight_denomaitor = \
                self.cfg.flitplus.factor_ema * ctx.weight_denomaitor + \
                (1 - self.cfg.flitplus.factor_ema) * \
                weightloss.mean(dim=0, keepdim=True).detach()

        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedVATTrainer(GeneralTorchTrainer):
    """Trainer that augments the focal-style re-weighting with a virtual
    adversarial training (VAT) regulariser on the local model.
    """
    def register_default_hooks_train(self):
        super(FedVATTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FedVATTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        if ctx.cur_mode == 'test':
            lossLocalVAT = torch.tensor(0.)
        else:
            vat_loss = VATLoss()  # default xi and eps
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FedVAT trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        # The weight combines the local supervised loss and the VAT loss.
        lossLocalLabel = ctx.criterion(pred, label)
        weightloss = lossLocalLabel + \
            self.cfg.flitplus.lambdavat * lossLocalVAT

        # Normalise the weights by an exponential moving average over batches.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ctx.weight_denomaitor = \
                self.cfg.flitplus.factor_ema * ctx.weight_denomaitor + \
                (1 - self.cfg.flitplus.factor_ema) * \
                weightloss.mean(dim=0, keepdim=True).detach()

        # Re-weighted objective: supervised loss plus the VAT regulariser.
        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * (
                    lossLocalLabel +
                    self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


def record_initialization_local(ctx):
    """Initialize the EMA weight denominator used for loss re-weighting.
    """
    ctx.weight_denomaitor = None


def del_initialization_local(ctx):
    """Clear the variable to avoid memory leakage.
    """
    ctx.weight_denomaitor = None


def record_initialization_global(ctx):
    """Record a frozen copy of the shared global model and keep it on cpu.
    """
    ctx.global_model = deepcopy(ctx.model)
    ctx.global_model.to(torch.device("cpu"))


def del_initialization_global(ctx):
    """Clear the variable to avoid memory leakage.
    """
    ctx.global_model = None
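

# Minimal sketch of the re-weighting scheme shared by the trainers above,
# run on dummy per-sample losses. The hyperparameter values below stand in
# for cfg.flitplus.tmpFed and cfg.flitplus.factor_ema and are assumed for
# illustration only, not taken from any shipped config.
if __name__ == "__main__":
    # Dummy per-sample losses of the local and the global model.
    loss_local = torch.tensor([0.2, 1.5, 0.7])
    loss_global = torch.tensor([0.4, 0.9, 0.7])
    tmp_fed, factor_ema = 0.5, 0.9  # assumed hyperparameter values

    # Up-weight samples where the local model lags behind the global one.
    weightloss = loss_local + torch.relu(loss_local - loss_global)

    # First batch: the EMA denominator is just the batch mean. Later batches
    # would update it as
    #   denominator = factor_ema * denominator + (1 - factor_ema) * batch_mean
    denominator = weightloss.mean(dim=0, keepdim=True)

    # Focal-style factor applied to the supervised loss of each sample.
    weighted = (1 - torch.exp(-weightloss / (denominator + 1e-7)) +
                1e-7)**tmp_fed * loss_local
    print(weighted.mean())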