# FS-TFP/federatedscope/attack/trainer/PIA_trainer.py
#
# 19 lines
# 612 B
# Python

from typing import Type
from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.attack.auxiliary.utils import get_data_property
def wrap_ActivePIATrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Prepare ``base_trainer`` for an active property-inference attack.

    Copies the loss-mixing weight ``cfg.attack.alpha_prop_loss`` from the
    trainer's config onto its context as ``ctx.alpha_prop_loss``, where
    ``hook_on_batch_forward_add_PIA_loss`` reads it.

    Args:
        base_trainer: the trainer to wrap. Although annotated as a
            ``Type``, it is used as an instance: its ``ctx`` and ``_cfg``
            attributes are accessed directly.

    Returns:
        The same ``base_trainer`` object, modified in place. Previously the
        function fell off the end and returned ``None`` despite its
        annotated return type.
    """
    alpha = base_trainer._cfg.attack.alpha_prop_loss
    base_trainer.ctx.alpha_prop_loss = alpha
    # NOTE(review): the per-batch hooks in this module
    # (``hood_on_batch_start_get_prop`` and
    # ``hook_on_batch_forward_add_PIA_loss``) are not registered here, so
    # wrapping alone does not alter training. TODO: register them through
    # the trainer's hook-registration API once trigger names are confirmed.
    return base_trainer
def hook_on_batch_start_get_prop(ctx):
    """Attach the data property of the current batch to the context.

    Stores ``get_data_property(ctx.data_batch)`` as ``ctx.prop``.
    ``hook_on_batch_forward_add_PIA_loss`` later passes ``ctx.prop`` to
    ``ctx.criterion`` as the target of the property loss.

    Args:
        ctx: trainer context; reads ``data_batch`` and writes ``prop``.
    """
    ctx.prop = get_data_property(ctx.data_batch)


# Backward-compatible alias for the original, misspelled ("hood") name so
# existing references keep working.
hood_on_batch_start_get_prop = hook_on_batch_start_get_prop
def hook_on_batch_forward_add_PIA_loss(ctx):
    """Mix the batch loss with a property-inference loss, in place.

    The new ``ctx.loss_batch`` is the weighted sum::

        alpha * loss_batch + (1 - alpha) * criterion(y_prob, prop)

    where ``alpha`` is ``ctx.alpha_prop_loss``.

    Args:
        ctx: trainer context; reads ``alpha_prop_loss``, ``loss_batch``,
            ``criterion``, ``y_prob`` and ``prop``; overwrites
            ``loss_batch``.
    """
    weight = ctx.alpha_prop_loss
    # Evaluate the task term before calling the criterion, matching the
    # left-to-right order of the original single expression.
    task_term = weight * ctx.loss_batch
    property_term = (1 - weight) * ctx.criterion(ctx.y_prob, ctx.prop)
    ctx.loss_batch = task_term + property_term