import logging
from typing import Type

import torch

from federatedscope.core.trainers import GeneralTorchTrainer

logger = logging.getLogger(__name__)


def wrap_GaussianAttackTrainer(
        base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]:
    '''
    Wrap the trainer for the Gaussian noise attack: the default backward
    hook is replaced by one that overwrites the true gradients with
    Gaussian noise whose mean (shifted by 0.1) and std are estimated
    from the true gradients.

    Args:
        base_trainer: Type: core.trainers.GeneralTorchTrainer

    :returns:
        The wrapped trainer; Type: core.trainers.GeneralTorchTrainer
    '''
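    # Swap the default backward hook for the Gaussian-noise hook
    # defined below.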
    base_trainer.replace_hook_in_train(
        new_hook=hook_on_batch_backward_generate_gaussian_noise_gradient,
        target_trigger='on_batch_backward',
        target_hook_name='_hook_on_batch_backward')

    return base_trainer
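# A minimal usage sketch (illustration only, not part of the original
# module); the GeneralTorchTrainer constructor arguments shown here are
# assumptions about its usual signature:
#
#     trainer = GeneralTorchTrainer(model=model, data=data,
#                                   device=device, config=cfg)
#     trainer = wrap_GaussianAttackTrainer(trainer)
#     trainer.train()  # the optimizer now steps on noise gradients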


def hook_on_batch_backward_generate_gaussian_noise_gradient(ctx):
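    # First run the genuine backward pass so that real gradient
    # statistics are available to parameterize the noise.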
    ctx.optimizer.zero_grad()
    ctx.loss_task.backward()

    if ctx.grad_clip > 0:
        torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
                                       ctx.grad_clip)
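    # Collect the flattened true gradients of all non-BatchNorm
    # parameters to estimate their mean and standard deviation.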
    grad_values = list()
    for name, param in ctx.model.named_parameters():
        if 'bn' not in name:
            grad_values.append(param.grad.detach().cpu().view(-1))

    grad_values = torch.cat(grad_values)
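    # The +0.1 shift biases the generated noise away from the true
    # gradient distribution.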
    mean_for_gaussian_noise = torch.mean(grad_values) + 0.1
    std_for_gaussian_noise = torch.std(grad_values)
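    # Overwrite every non-BatchNorm gradient with Gaussian noise drawn
    # from the estimated (shifted) distribution, then step as usual.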
    for name, param in ctx.model.named_parameters():
        if 'bn' not in name:
            generated_grad = torch.normal(mean=mean_for_gaussian_noise,
                                          std=std_for_gaussian_noise,
                                          size=param.grad.shape)
            param.grad = generated_grad.to(param.grad.device)

    ctx.optimizer.step()