FS-TFP/federatedscope/attack/trainer/backdoor_trainer.py

181 lines
5.9 KiB
Python

import logging
from typing import Type
import torch
import numpy as np
import copy
from federatedscope.core.trainers import GeneralTorchTrainer
from torch.nn.utils import parameters_to_vector, vector_to_parameters
logger = logging.getLogger(__name__)
def wrap_backdoorTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    """Attach backdoor-attack settings and training hooks to a trainer.

    Values are copied from ``base_trainer.cfg.attack`` onto
    ``base_trainer.ctx`` and hooks are added through
    ``base_trainer.register_hook_in_train``. Which hooks get registered
    depends on the ``self_opt``, ``scale_poisoning`` and ``pgd_poisoning``
    switches of the attack config.

    Args:
        base_trainer: trainer to modify in place.
            NOTE(review): annotated as ``Type[...]`` but accessed like an
            instance here (``.ctx``, ``.cfg``) - confirm the annotation.

    Returns:
        The same ``base_trainer`` object that was passed in.
    """
    attack_cfg = base_trainer.cfg.attack
    ctx = base_trainer.ctx
    add_train_hook = base_trainer.register_hook_in_train

    # ---------------- attribute-level plug-in -----------------------
    ctx.target_label_ind = attack_cfg.target_label_ind
    ctx.trigger_type = attack_cfg.trigger_type
    ctx.label_type = attack_cfg.label_type

    # ---------------- action-level plug-in --------------------------
    if attack_cfg.self_opt:
        ctx.self_lr = attack_cfg.self_lr
        ctx.self_epoch = attack_cfg.self_epoch
        # Swap the epoch budget in at the end of the fit-start hooks and
        # restore it before any other fit-end hook runs.
        add_train_hook(new_hook=hook_on_fit_start_init_local_opt,
                       trigger='on_fit_start',
                       insert_pos=-1)
        add_train_hook(new_hook=hook_on_fit_end_reset_opt,
                       trigger='on_fit_end',
                       insert_pos=0)

    use_scaling = attack_cfg.scale_poisoning
    use_pgd = attack_cfg.pgd_poisoning

    if use_scaling or use_pgd:
        # Both attacks compare against a copy of the model taken at fit start.
        add_train_hook(new_hook=hook_on_fit_start_init_local_model,
                       trigger='on_fit_start',
                       insert_pos=-1)

    if use_scaling:
        ctx.scale_para = attack_cfg.scale_para
        add_train_hook(new_hook=hook_on_fit_end_scale_poisoning,
                       trigger="on_fit_end",
                       insert_pos=-1)

    if use_pgd:
        ctx.self_epoch = attack_cfg.self_epoch
        ctx.pgd_lr = attack_cfg.pgd_lr
        ctx.pgd_eps = attack_cfg.pgd_eps
        ctx.batch_index = 0

        # NOTE(review): hook_on_fit_end_reset_pgd is never registered here,
        # so the optimizer replaced at fit start is not restored by this
        # module - confirm the framework rebuilds it on every fit.
        pgd_hooks = (
            (hook_on_fit_start_init_local_pgd, 'on_fit_start', -1),
            (hook_on_batch_end_project_grad, 'on_batch_end', -1),
            (hook_on_epoch_end_project_grad, 'on_epoch_end', -1),
            (hook_on_fit_end_reset_opt, 'on_fit_end', 0),
        )
        for hook, trigger, position in pgd_hooks:
            add_train_hook(new_hook=hook,
                           trigger=trigger,
                           insert_pos=position)

    return base_trainer
def hook_on_fit_start_init_local_opt(ctx):
    """Switch the epoch budget to the attacker's own value for this fit.

    The current ``ctx["num_train_epoch"]`` is saved as
    ``ctx.original_epoch`` and then replaced by ``ctx.self_epoch``.

    Args:
        ctx: trainer context supporting both item and attribute access.
    """
    epoch_key = "num_train_epoch"
    previous_epochs = ctx[epoch_key]
    ctx.original_epoch = previous_epochs
    ctx[epoch_key] = ctx.self_epoch
def hook_on_fit_end_reset_opt(ctx):
    """Put back the epoch budget that was saved when fitting started.

    Args:
        ctx: trainer context; ``ctx.original_epoch`` is written back to
            ``ctx["num_train_epoch"]``.
    """
    saved_epochs = ctx.original_epoch
    ctx["num_train_epoch"] = saved_epochs
def hook_on_fit_start_init_local_model(ctx):
    """Store an independent deep copy of ``ctx.model`` on the context.

    The copy is saved as ``ctx.original_model`` so later hooks can compare
    the trained weights against the model as it was when fitting started
    (presumably the global model just received - verify against caller).

    Args:
        ctx: trainer context holding ``model``.
    """
    snapshot = copy.deepcopy(ctx.model)
    ctx.original_model = snapshot
def hook_on_fit_end_scale_poisoning(ctx):
    """Amplify the attacker's local update by ``ctx.scale_para``.

    Every parameter of ``ctx.model`` is rewritten in place as
    ``original + scale_para * (local - original)``, where ``original`` is
    the matching parameter of ``ctx.original_model``. The L2 norm of the
    flattened weights is logged for the original model and for the local
    model before and after scaling.

    Args:
        ctx: trainer context; reads ``scale_para``, ``model`` and
            ``original_model`` and modifies only the parameters of
            ``model``.
    """
    scale_para = ctx.scale_para
    # Keep the parameter list local so ctx.original_model remains a module
    # for any later hook that calls ``.parameters()`` on it.
    original_params = list(ctx.original_model.parameters())

    # This is a pure weight rewrite; running it without autograd avoids
    # building a graph through the (grad-requiring) original parameters.
    with torch.no_grad():
        v = torch.nn.utils.parameters_to_vector(original_params)
        logger.info("the Norm of the original global model: {}".format(
            torch.norm(v).item()))

        v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
        logger.info("Attacker before scaling : Norm = {}".format(
            torch.norm(v).item()))

        for param, original in zip(ctx.model.parameters(), original_params):
            param.data = (param.data -
                          original.data) * scale_para + original.data

        v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
        logger.info("Attacker after scaling : Norm = {}".format(
            torch.norm(v).item()))

    logger.info('finishing model scaling poisoning attack')
def hook_on_fit_start_init_local_pgd(ctx):
    """Install the attacker's PGD training settings before local fitting.

    The current optimizer and ``ctx["num_train_epoch"]`` are stashed as
    ``ctx.original_optimizer`` and ``ctx.original_epoch``. Training then
    runs for ``ctx.self_epoch`` epochs with a fresh SGD optimizer over
    ``ctx.model`` at learning rate ``ctx.pgd_lr``.

    Args:
        ctx: trainer context supporting both item and attribute access.
    """
    epoch_key = "num_train_epoch"

    # Remember what is being replaced so it can be restored later.
    ctx.original_optimizer = ctx.optimizer
    ctx.original_epoch = ctx[epoch_key]

    ctx[epoch_key] = ctx.self_epoch
    pgd_optimizer = torch.optim.SGD(ctx.model.parameters(), lr=ctx.pgd_lr)
    ctx.optimizer = pgd_optimizer
    # NOTE: it looks like the adversary needs the same lr as the other
    # clients in order to hide among them.
def _project_onto_norm_ball(ctx):
    """Project ``ctx.model`` onto the L2 ball around ``ctx.original_model``.

    If the flattened difference between the two models has norm greater
    than ``ctx.pgd_eps``, it is rescaled to exactly that radius and the
    result is written back into the parameters of ``ctx.model``. Otherwise
    the model is left unchanged.
    """
    eps = ctx.pgd_eps
    w = list(ctx.model.parameters())

    # Pure weight rewrite: no autograd graph is needed for the projection.
    with torch.no_grad():
        w_vec = parameters_to_vector(w)
        model_original_vec = parameters_to_vector(
            list(ctx.original_model.parameters()))
        delta = w_vec - model_original_vec
        delta_norm = torch.norm(delta)

        if delta_norm > eps:
            # project back into norm ball
            w_proj_vec = eps * delta / delta_norm + model_original_vec
            # plug w_proj back into model
            vector_to_parameters(w_proj_vec, w)


def hook_on_batch_end_project_grad(ctx):
    """Count the batch and, after every 10 iters, project the update.

    ``ctx.batch_index`` is incremented on every call. Only when it reaches
    a multiple of 10 is the model projected onto the predefined norm ball,
    so the models are not flattened on the batches in between.

    Args:
        ctx: trainer context; reads ``pgd_eps``, ``model`` and
            ``original_model`` and updates ``batch_index``.
    """
    project_frequency = 10
    ctx.batch_index += 1

    # The epoch-end hook always projects, so a high learning rate cannot
    # leave the final weights far outside the ball.
    if ctx.batch_index % project_frequency == 0:
        _project_onto_norm_ball(ctx)


def hook_on_epoch_end_project_grad(ctx):
    """Reset the batch counter and project the update after a whole epoch.

    Unlike the batch-end hook, the projection onto the predefined norm ball
    is attempted unconditionally.

    Args:
        ctx: trainer context; reads ``pgd_eps``, ``model`` and
            ``original_model`` and sets ``batch_index`` to 0.
    """
    ctx.batch_index = 0
    _project_onto_norm_ball(ctx)
def hook_on_fit_end_reset_pgd(ctx):
    """Restore the optimizer that was stashed as ``ctx.original_optimizer``.

    Args:
        ctx: trainer context; ``optimizer`` is overwritten with
            ``original_optimizer``.
    """
    saved_optimizer = ctx.original_optimizer
    ctx.optimizer = saved_optimizer