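'''Backdoor-attack trainer wrapper for FederatedScope.

The hooks registered here add three optional attacker-side behaviours,
switched on via ``cfg.attack``: a private local-training budget
(``self_opt``), scale poisoning that amplifies the malicious update before
it is uploaded (``scale_poisoning``), and PGD-style poisoning that keeps the
update inside an eps-ball around the received global model
(``pgd_poisoning``).
'''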
import copy
import logging
from typing import Type

import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from federatedscope.core.trainers import GeneralTorchTrainer

logger = logging.getLogger(__name__)

def wrap_backdoorTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
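    '''Wrap a benign trainer with the backdoor-attack plug-in hooks.

    The wrapper reads ``cfg.attack.target_label_ind``, ``trigger_type`` and
    ``label_type`` as attribute-level plug-ins, and registers action-level
    hooks according to ``cfg.attack.self_opt`` (with ``self_lr`` and
    ``self_epoch``), ``cfg.attack.scale_poisoning`` (with ``scale_para``)
    and ``cfg.attack.pgd_poisoning`` (with ``pgd_lr`` and ``pgd_eps``).

    Args:
        base_trainer: the ``GeneralTorchTrainer`` to be wrapped.

    Returns:
        The same trainer instance with the attack hooks registered.
    '''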
    # ---------------- attribute-level plug-in -----------------------
    base_trainer.ctx.target_label_ind \
        = base_trainer.cfg.attack.target_label_ind
    base_trainer.ctx.trigger_type = base_trainer.cfg.attack.trigger_type
    base_trainer.ctx.label_type = base_trainer.cfg.attack.label_type

    # ---------------- action-level plug-in --------------------------
    if base_trainer.cfg.attack.self_opt:

        base_trainer.ctx.self_lr = base_trainer.cfg.attack.self_lr
        base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_opt,
            trigger='on_fit_start',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_end_reset_opt,
            trigger='on_fit_end',
            insert_pos=0)

    scale_poisoning = base_trainer.cfg.attack.scale_poisoning
    pgd_poisoning = base_trainer.cfg.attack.pgd_poisoning

    if scale_poisoning or pgd_poisoning:

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_model,
            trigger='on_fit_start',
            insert_pos=-1)

    if base_trainer.cfg.attack.scale_poisoning:

        base_trainer.ctx.scale_para = base_trainer.cfg.attack.scale_para

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_end_scale_poisoning,
            trigger='on_fit_end',
            insert_pos=-1)

    if base_trainer.cfg.attack.pgd_poisoning:

        base_trainer.ctx.self_epoch = base_trainer.cfg.attack.self_epoch
        base_trainer.ctx.pgd_lr = base_trainer.cfg.attack.pgd_lr
        base_trainer.ctx.pgd_eps = base_trainer.cfg.attack.pgd_eps
        base_trainer.ctx.batch_index = 0

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_start_init_local_pgd,
            trigger='on_fit_start',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_batch_end_project_grad,
            trigger='on_batch_end',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_epoch_end_project_grad,
            trigger='on_epoch_end',
            insert_pos=-1)

        base_trainer.register_hook_in_train(
            new_hook=hook_on_fit_end_reset_opt,
            trigger='on_fit_end',
            insert_pos=0)

    return base_trainer


def hook_on_fit_start_init_local_opt(ctx):
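    '''Temporarily enlarge the attacker's local training budget.

    Saves the benign ``num_train_epoch`` and replaces it with the private
    ``ctx.self_epoch`` so the attacker can train longer on poisoned data.
    '''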
    ctx.original_epoch = ctx["num_train_epoch"]
    ctx["num_train_epoch"] = ctx.self_epoch


def hook_on_fit_end_reset_opt(ctx):
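    '''Restore the benign ``num_train_epoch`` after the local fit.'''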
    ctx["num_train_epoch"] = ctx.original_epoch


def hook_on_fit_start_init_local_model(ctx):
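    '''Keep a deep copy of the received global model for the later hooks.'''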
    # the original global model
    ctx.original_model = copy.deepcopy(ctx.model)


def hook_on_fit_end_scale_poisoning(ctx):
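    '''Scale up the attacker's update (model-replacement style poisoning).

    Each parameter is reset to ``original + scale_para * (param - original)``,
    i.e. the deviation from the received global model is amplified by
    ``scale_para`` before the model is sent back to the server.
    '''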
    # conduct the scale poisoning
    scale_para = ctx.scale_para

    v = torch.nn.utils.parameters_to_vector(ctx.original_model.parameters())
    logger.info("Norm of the original global model: {}".format(
        torch.norm(v).item()))

    v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
    logger.info("Attacker before scaling: Norm = {}".format(
        torch.norm(v).item()))

    ctx.original_model = list(ctx.original_model.parameters())

    for idx, param in enumerate(ctx.model.parameters()):
        param.data = (param.data - ctx.original_model[idx]
                      ) * scale_para + ctx.original_model[idx]

    v = torch.nn.utils.parameters_to_vector(ctx.model.parameters())
    logger.info("Attacker after scaling: Norm = {}".format(
        torch.norm(v).item()))

    logger.info('Finished the model scale poisoning attack')


def hook_on_fit_start_init_local_pgd(ctx):
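    '''Switch to the attacker's PGD training setting for this local fit.

    Saves the benign optimizer and epoch budget, enlarges the local epochs to
    ``ctx.self_epoch`` and replaces the optimizer with plain SGD at
    ``ctx.pgd_lr``.
    '''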
    ctx.original_optimizer = ctx.optimizer
    ctx.original_epoch = ctx["num_train_epoch"]
    ctx["num_train_epoch"] = ctx.self_epoch
    ctx.optimizer = torch.optim.SGD(ctx.model.parameters(), lr=ctx.pgd_lr)
    # the adversary seems to need the same lr as benign clients to stay
    # inconspicuous


def hook_on_batch_end_project_grad(ctx):
    '''
    Every ``project_frequency`` batches, project the model update back onto
    the predefined eps-ball around the original global model.
    '''
    eps = ctx.pgd_eps
    project_frequency = 10
    ctx.batch_index += 1
    w = list(ctx.model.parameters())
    w_vec = parameters_to_vector(w)
    model_original_vec = parameters_to_vector(
        list(ctx.original_model.parameters()))
    # make sure to also project on the last iteration; otherwise a high
    # learning rate pushes the update far outside the ball
    if (ctx.batch_index % project_frequency
            == 0) and (torch.norm(w_vec - model_original_vec) > eps):
        # project back into the norm ball
        w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
            w_vec - model_original_vec) + model_original_vec
        # plug the projected weights back into the model
        vector_to_parameters(w_proj_vec, w)


def hook_on_epoch_end_project_grad(ctx):
    '''
    At the end of every epoch, project the model update back onto the
    predefined eps-ball around the original global model.
    '''
    ctx.batch_index = 0
    eps = ctx.pgd_eps
    w = list(ctx.model.parameters())
    w_vec = parameters_to_vector(w)
    model_original_vec = parameters_to_vector(
        list(ctx.original_model.parameters()))
    if torch.norm(w_vec - model_original_vec) > eps:
        # project back into the norm ball
        w_proj_vec = eps * (w_vec - model_original_vec) / torch.norm(
            w_vec - model_original_vec) + model_original_vec
        # plug the projected weights back into the model
        vector_to_parameters(w_proj_vec, w)


def hook_on_fit_end_reset_pgd(ctx):
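    '''Restore the optimizer saved by ``hook_on_fit_start_init_local_pgd``.

    Note that this hook is defined but not registered in
    ``wrap_backdoorTrainer`` above.
    '''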
    ctx.optimizer = ctx.original_optimizer
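

# A minimal usage sketch (hypothetical direct call; FederatedScope normally
# selects this wrapper through its attack configuration): given a benign
# ``base_trainer`` whose ``cfg.attack.*`` fields are set as documented above,
# an attacker-side client would run
#
#     trainer = wrap_backdoorTrainer(base_trainer)
#     trainer.train()
#
# and the registered hooks enlarge the local budget and rescale or project
# the poisoned update at their trigger points.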