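"""Trainers for difficulty-aware federated learning on graphs.

This module provides the FLIT(+) family of trainers used in the graph FL
(``gfl``) package of FederatedScope.  All of them re-weight the training loss
of a batch by a "difficulty" score normalized with an exponential moving
average (EMA):

    loss = (1 - exp(-w / (d + eps)) + eps) ** cfg.flitplus.tmpFed * base_loss

where ``w`` is the difficulty term, ``d`` its EMA (``ctx.weight_denomaitor``),
``eps = 1e-7`` and ``base_loss`` the supervised loss (plus a VAT regularizer
where applicable).

* ``FLITTrainer``: difficulty is the local loss plus the gap to the loss of
  the shared global model.
* ``FLITPlusTrainer``: additionally mixes in a VAT-based consistency gap.
* ``FedFocalTrainer``: difficulty is the local loss alone (focal-style).
* ``FedVATTrainer``: difficulty mixes the local loss with the local VAT loss.

The configuration entries read below are ``cfg.flitplus.tmpFed``,
``cfg.flitplus.factor_ema``, ``cfg.flitplus.lambdavat`` and
``cfg.flitplus.weightReg``.
"""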
import torch
from copy import deepcopy

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.gfl.loss.vat import VATLoss
from federatedscope.core.trainers import GeneralTorchTrainer


class FLITTrainer(GeneralTorchTrainer):
    def register_default_hooks_train(self):
        super(FLITTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=record_initialization_global,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_global,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FLITTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=record_initialization_global,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_global,
                                   trigger='on_fit_end',
                                   insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        # Forward the batch through both the local model and the frozen copy
        # of the shared global model recorded at the start of the fit.
        pred = ctx.model(batch)
        ctx.global_model.to(ctx.device)
        predG = ctx.global_model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FLIT trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Difficulty term: the local loss plus the (non-negative) amount by
        # which it exceeds the detached loss of the global model.
        weightloss = lossLocalLabel + torch.relu(lossLocalLabel -
                                                 lossGlobalLabel.detach())
        # Track an EMA of the mean difficulty as the normalizing denominator.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = ema * ctx.weight_denomaitor + \
                (1 - ema) * weightloss.mean(dim=0, keepdim=True).detach()

        # Re-weight the local loss so that samples the local model handles
        # worse than the running average contribute more; tmpFed controls
        # the sharpness of the weighting.
        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


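# FLITPlusTrainer extends FLITTrainer with a virtual adversarial training
# (VAT) consistency term: the difficulty score additionally includes the gap
# between the local and global VAT losses (scaled by cfg.flitplus.lambdavat),
# and the re-weighted objective adds the local VAT loss scaled by
# cfg.flitplus.weightReg.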
class FLITPlusTrainer(FLITTrainer):
    def _hook_on_batch_forward(self, ctx):
        # The LDS (VAT) losses should be calculated before the forward pass
        # for the supervised (cross entropy / MSE) loss.
        batch = ctx.data_batch.to(ctx.device)
        ctx.global_model.to(ctx.device)
        if ctx.cur_mode == 'test':
            lossLocalVAT, lossGlobalVAT = torch.tensor(0.), torch.tensor(0.)
        else:
            vat_loss = VATLoss()  # default xi and eps
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))
            lossGlobalVAT = vat_loss(deepcopy(ctx.global_model), batch,
                                     deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        predG = ctx.global_model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FLITPlus trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)
        lossGlobalLabel = ctx.criterion(predG, label)
        lossLocalLabel = ctx.criterion(pred, label)

        # Combine the supervised difficulty term with the VAT (consistency)
        # difficulty term, weighted by lambdavat.
        weightloss_loss = lossLocalLabel + torch.relu(
            lossLocalLabel - lossGlobalLabel.detach())
        weightloss_vat = lossLocalVAT + torch.relu(
            lossLocalVAT - lossGlobalVAT.detach())
        weightloss = self.cfg.flitplus.lambdavat * weightloss_vat + \
            weightloss_loss
        # Track an EMA of the mean difficulty as the normalizing denominator.
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = ema * ctx.weight_denomaitor + \
                (1 - ema) * weightloss.mean(dim=0, keepdim=True).detach()

        # Re-weight the supervised loss plus the VAT regularizer (scaled by
        # weightReg), tempered by tmpFed.
        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * (
                    lossLocalLabel +
                    self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedFocalTrainer(GeneralTorchTrainer):
    def register_default_hooks_train(self):
        super(FedFocalTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FedFocalTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        pred = ctx.model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FedFocal trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)

        # Focal-style variant: the difficulty term is the local loss alone,
        # with no comparison against a global model.
        lossLocalLabel = ctx.criterion(pred, label)
        weightloss = lossLocalLabel
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = ema * ctx.weight_denomaitor + \
                (1 - ema) * weightloss.mean(dim=0, keepdim=True).detach()

        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * lossLocalLabel
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


class FedVATTrainer(GeneralTorchTrainer):
    def register_default_hooks_train(self):
        super(FedVATTrainer, self).register_default_hooks_train()
        self.register_hook_in_train(new_hook=record_initialization_local,
                                    trigger='on_fit_start',
                                    insert_pos=-1)
        self.register_hook_in_train(new_hook=del_initialization_local,
                                    trigger='on_fit_end',
                                    insert_pos=-1)

    def register_default_hooks_eval(self):
        super(FedVATTrainer, self).register_default_hooks_eval()
        self.register_hook_in_eval(new_hook=record_initialization_local,
                                   trigger='on_fit_start',
                                   insert_pos=-1)
        self.register_hook_in_eval(new_hook=del_initialization_local,
                                   trigger='on_fit_end',
                                   insert_pos=-1)

    def _hook_on_batch_forward(self, ctx):
        batch = ctx.data_batch.to(ctx.device)
        if ctx.cur_mode == 'test':
            lossLocalVAT = torch.tensor(0.)
        else:
            vat_loss = VATLoss()  # default xi and eps
            lossLocalVAT = vat_loss(deepcopy(ctx.model), batch,
                                    deepcopy(ctx.criterion))

        pred = ctx.model(batch)
        if ctx.criterion._get_name() == 'CrossEntropyLoss':
            label = batch.y.squeeze(-1).long()
        elif ctx.criterion._get_name() == 'MSELoss':
            label = batch.y.float()
        else:
            raise ValueError(f'FedVAT trainer does not support '
                             f'{ctx.criterion._get_name()}.')
        if len(label.size()) == 0:
            label = label.unsqueeze(0)
        lossLocalLabel = ctx.criterion(pred, label)
        # VAT variant: the difficulty term combines the local supervised loss
        # with the local VAT (consistency) loss, weighted by lambdavat.
        weightloss = lossLocalLabel + \
            self.cfg.flitplus.lambdavat * lossLocalVAT
        if ctx.weight_denomaitor is None:
            ctx.weight_denomaitor = weightloss.mean(dim=0,
                                                    keepdim=True).detach()
        else:
            ema = self.cfg.flitplus.factor_ema
            ctx.weight_denomaitor = ema * ctx.weight_denomaitor + \
                (1 - ema) * weightloss.mean(dim=0, keepdim=True).detach()

        loss = (1 - torch.exp(-weightloss / (ctx.weight_denomaitor + 1e-7)) +
                1e-7)**self.cfg.flitplus.tmpFed * (
                    lossLocalLabel +
                    self.cfg.flitplus.weightReg * lossLocalVAT)
        ctx.loss_batch = loss.mean()

        ctx.batch_size = len(label)
        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)


def record_initialization_local(ctx):
    """Initialize the EMA denominator used to normalize the weighted loss."""
    ctx.weight_denomaitor = None


def del_initialization_local(ctx):
    """Clear the EMA denominator to avoid memory leakage."""
    ctx.weight_denomaitor = None


def record_initialization_global(ctx):
    """Record a frozen copy of the shared global model and keep it on cpu."""
    ctx.global_model = deepcopy(ctx.model)
    ctx.global_model.to(torch.device("cpu"))


def del_initialization_global(ctx):
    """Clear the global-model copy to avoid memory leakage."""
    ctx.global_model = None
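
# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the library API): the re-weighting shared
# by the trainers above can be reproduced in isolation with plain PyTorch,
# e.g. with a hypothetical helper such as
#
#     def flit_reweight(base_loss, difficulty, denominator, gamma, eps=1e-7):
#         factor = (1 - torch.exp(-difficulty / (denominator + eps)) + eps)
#         return (factor ** gamma * base_loss).mean()
#
# where ``gamma`` plays the role of ``cfg.flitplus.tmpFed`` and
# ``denominator`` is the EMA stored in ``ctx.weight_denomaitor``.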