377 lines
16 KiB
Python
377 lines
16 KiB
Python
from federatedscope.core.workers import Server
|
|
from federatedscope.core.message import Message
|
|
|
|
from federatedscope.core.auxiliaries.criterion_builder import get_criterion
|
|
import copy
|
|
from federatedscope.attack.auxiliary.utils import get_data_sav_fn, \
|
|
get_reconstructor
|
|
|
|
import logging
|
|
|
|
import torch
|
|
import numpy as np
|
|
from federatedscope.attack.privacy_attacks.passive_PIA import \
|
|
PassivePropertyInference
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class BackdoorServer(Server):
|
|
'''
|
|
For backdoor attacks, we will choose different sampling stratergies.
|
|
fix-frequency, all-round ,or random sampling.
|
|
'''
|
|
def __init__(self,
|
|
ID=-1,
|
|
state=0,
|
|
config=None,
|
|
data=None,
|
|
model=None,
|
|
client_num=5,
|
|
total_round_num=10,
|
|
device='cpu',
|
|
strategy=None,
|
|
unseen_clients_id=None,
|
|
**kwargs):
|
|
super(BackdoorServer, self).__init__(ID=ID,
|
|
state=state,
|
|
data=data,
|
|
model=model,
|
|
config=config,
|
|
client_num=client_num,
|
|
total_round_num=total_round_num,
|
|
device=device,
|
|
strategy=strategy,
|
|
**kwargs)
|
|
|
|
def broadcast_model_para(self,
|
|
msg_type='model_para',
|
|
sample_client_num=-1,
|
|
filter_unseen_clients=True):
|
|
"""
|
|
To broadcast the message to all clients or sampled clients
|
|
|
|
Arguments:
|
|
msg_type: 'model_para' or other user defined msg_type
|
|
sample_client_num: the number of sampled clients in the broadcast
|
|
behavior. And sample_client_num = -1 denotes to broadcast to
|
|
all the clients.
|
|
filter_unseen_clients: whether filter out the unseen clients that
|
|
do not contribute to FL process by training on their local
|
|
data and uploading their local model update. The splitting is
|
|
useful to check participation generalization gap in [ICLR'22,
|
|
What Do We Mean by Generalization in Federated Learning?]
|
|
You may want to set it to be False when in evaluation stage
|
|
"""
|
|
|
|
if filter_unseen_clients:
|
|
# to filter out the unseen clients when sampling
|
|
self.sampler.change_state(self.unseen_clients_id, 'unseen')
|
|
|
|
if sample_client_num > 0: # only activated at training process
|
|
attacker_id = self._cfg.attack.attacker_id
|
|
setting = self._cfg.attack.setting
|
|
insert_round = self._cfg.attack.insert_round
|
|
|
|
if attacker_id == -1 or self._cfg.attack.attack_method == '':
|
|
|
|
receiver = np.random.choice(np.arange(1, self.client_num + 1),
|
|
size=sample_client_num,
|
|
replace=False).tolist()
|
|
|
|
elif setting == 'fix':
|
|
if self.state % self._cfg.attack.freq == 0:
|
|
client_list = np.delete(np.arange(1, self.client_num + 1),
|
|
self._cfg.attack.attacker_id - 1)
|
|
receiver = np.random.choice(client_list,
|
|
size=sample_client_num - 1,
|
|
replace=False).tolist()
|
|
receiver.insert(0, self._cfg.attack.attacker_id)
|
|
logger.info('starting the fix-frequency poisoning attack')
|
|
logger.info(
|
|
'starting poisoning round: {:d}, the attacker ID: {:d}'
|
|
.format(self.state, self._cfg.attack.attacker_id))
|
|
else:
|
|
client_list = np.delete(np.arange(1, self.client_num + 1),
|
|
self._cfg.attack.attacker_id - 1)
|
|
receiver = np.random.choice(client_list,
|
|
size=sample_client_num,
|
|
replace=False).tolist()
|
|
|
|
elif setting == 'single' and self.state == insert_round:
|
|
client_list = np.delete(np.arange(1, self.client_num + 1),
|
|
self._cfg.attack.attacker_id - 1)
|
|
receiver = np.random.choice(client_list,
|
|
size=sample_client_num - 1,
|
|
replace=False).tolist()
|
|
receiver.insert(0, self._cfg.attack.attacker_id)
|
|
logger.info('starting the single-shot poisoning attack')
|
|
logger.info(
|
|
'starting poisoning round: {:d}, the attacker ID: {:d}'.
|
|
format(self.state, self._cfg.attack.attacker_id))
|
|
|
|
elif self._cfg.attack.setting == 'all':
|
|
|
|
client_list = np.delete(np.arange(1, self.client_num + 1),
|
|
self._cfg.attack.attacker_id - 1)
|
|
receiver = np.random.choice(client_list,
|
|
size=sample_client_num - 1,
|
|
replace=False).tolist()
|
|
receiver.insert(0, self._cfg.attack.attacker_id)
|
|
logger.info('starting the all-round poisoning attack')
|
|
logger.info(
|
|
'starting poisoning round: {:d}, the attacker ID: {:d}'.
|
|
format(self.state, self._cfg.attack.attacker_id))
|
|
|
|
else:
|
|
receiver = np.random.choice(np.arange(1, self.client_num + 1),
|
|
size=sample_client_num,
|
|
replace=False).tolist()
|
|
|
|
else:
|
|
# broadcast to all clients
|
|
receiver = list(self.comm_manager.neighbors.keys())
|
|
|
|
if self._noise_injector is not None and msg_type == 'model_para':
|
|
# Inject noise only when broadcast parameters
|
|
for model_idx_i in range(len(self.models)):
|
|
num_sample_clients = [
|
|
v["num_sample"] for v in self.join_in_info.values()
|
|
]
|
|
self._noise_injector(self._cfg, num_sample_clients,
|
|
self.models[model_idx_i])
|
|
|
|
skip_broadcast = self._cfg.federate.method in ["local", "global"]
|
|
if self.model_num > 1:
|
|
model_para = [{} if skip_broadcast else model.state_dict()
|
|
for model in self.models]
|
|
else:
|
|
model_para = {} if skip_broadcast else self.model.state_dict()
|
|
|
|
self.comm_manager.send(
|
|
Message(msg_type=msg_type,
|
|
sender=self.ID,
|
|
receiver=receiver,
|
|
state=min(self.state, self.total_round_num),
|
|
content=model_para))
|
|
if self._cfg.federate.online_aggr:
|
|
for idx in range(self.model_num):
|
|
self.aggregators[idx].reset()
|
|
|
|
if filter_unseen_clients:
|
|
# restore the state of the unseen clients within sampler
|
|
self.sampler.change_state(self.unseen_clients_id, 'seen')
|
|
|
|
|
|
class PassiveServer(Server):
|
|
'''
|
|
In passive attack, the server store the model and the message collected
|
|
from the client,and perform the optimization based reconstruction,
|
|
such as DLG, InvertGradient.
|
|
'''
|
|
def __init__(self,
|
|
ID=-1,
|
|
state=0,
|
|
data=None,
|
|
model=None,
|
|
client_num=5,
|
|
total_round_num=10,
|
|
device='cpu',
|
|
strategy=None,
|
|
state_to_reconstruct=None,
|
|
client_to_reconstruct=None,
|
|
**kwargs):
|
|
super(PassiveServer, self).__init__(ID=ID,
|
|
state=state,
|
|
data=data,
|
|
model=model,
|
|
client_num=client_num,
|
|
total_round_num=total_round_num,
|
|
device=device,
|
|
strategy=strategy,
|
|
**kwargs)
|
|
|
|
# self.offline_reconstruct = offline_reconstruct
|
|
self.atk_method = self._cfg.attack.attack_method
|
|
self.state_to_reconstruct = state_to_reconstruct
|
|
self.client_to_reconstruct = client_to_reconstruct
|
|
self.reconstruct_data = dict()
|
|
|
|
# the loss function of the global model; the global model can be
|
|
# obtained in self.aggregator.model
|
|
self.model_criterion = get_criterion(self._cfg.criterion.type,
|
|
device=self.device)
|
|
|
|
from federatedscope.attack.auxiliary.utils import get_data_info
|
|
self.data_dim, self.num_class, self.is_one_hot_label = get_data_info(
|
|
self._cfg.data.type)
|
|
|
|
self.reconstructor = self._get_reconstructor()
|
|
|
|
self.reconstructed_data_sav_fn = get_data_sav_fn(self._cfg.data.type)
|
|
|
|
self.reconstruct_data_summary = dict()
|
|
|
|
def _get_reconstructor(self):
|
|
|
|
return get_reconstructor(
|
|
self.atk_method,
|
|
max_ite=self._cfg.attack.max_ite,
|
|
lr=self._cfg.attack.reconstruct_lr,
|
|
federate_loss_fn=self.model_criterion,
|
|
device=self.device,
|
|
federate_lr=self._cfg.train.optimizer.lr,
|
|
optim=self._cfg.attack.reconstruct_optim,
|
|
info_diff_type=self._cfg.attack.info_diff_type,
|
|
federate_method=self._cfg.federate.method,
|
|
alpha_TV=self._cfg.attack.alpha_TV)
|
|
|
|
def _reconstruct(self, model_para, batch_size, state, sender):
|
|
logger.info('-------- reconstruct round:{}, client:{}---------'.format(
|
|
state, sender))
|
|
dummy_data, dummy_label = self.reconstructor.reconstruct(
|
|
model=copy.deepcopy(self.model).to(torch.device(self.device)),
|
|
original_info=model_para,
|
|
data_feature_dim=self.data_dim,
|
|
num_class=self.num_class,
|
|
batch_size=batch_size)
|
|
if state not in self.reconstruct_data.keys():
|
|
self.reconstruct_data[state] = dict()
|
|
self.reconstruct_data[state][sender] = [
|
|
dummy_data.cpu(), dummy_label.cpu()
|
|
]
|
|
|
|
def run_reconstruct(self, state_list=None, sender_list=None):
|
|
|
|
if state_list is None:
|
|
state_list = self.msg_buffer['train'].keys()
|
|
|
|
# After FL running, using gradient based reconstruction method to
|
|
# recover client's private training data
|
|
for state in state_list:
|
|
if sender_list is None:
|
|
sender_list = self.msg_buffer['train'][state].keys()
|
|
for sender in sender_list:
|
|
content = self.msg_buffer['train'][state][sender]
|
|
self._reconstruct(model_para=content[1],
|
|
batch_size=content[0],
|
|
state=state,
|
|
sender=sender)
|
|
|
|
def callback_funcs_model_para(self, message: Message):
|
|
if self.is_finish:
|
|
return 'finish'
|
|
|
|
round, sender, content = message.state, message.sender, message.content
|
|
self.sampler.change_state(sender, 'idle')
|
|
if round not in self.msg_buffer['train']:
|
|
self.msg_buffer['train'][round] = dict()
|
|
self.msg_buffer['train'][round][sender] = content
|
|
|
|
# run reconstruction before the clear of self.msg_buffer
|
|
if 'DLG_loss' not in self.best_results.keys():
|
|
self.best_results['DLG_loss'] = {}
|
|
if round not in self.best_results['DLG_loss'].keys():
|
|
self.best_results['DLG_loss'][round] = {}
|
|
|
|
if self.state_to_reconstruct is None or message.state in \
|
|
self.state_to_reconstruct:
|
|
if self.client_to_reconstruct is None or message.sender in \
|
|
self.client_to_reconstruct:
|
|
self.run_reconstruct(state_list=[message.state],
|
|
sender_list=[message.sender])
|
|
logger.info(
|
|
'Finish DLG attack; Final DLG reconstruction loss: {}'.
|
|
format(self.reconstructor.dlg_recover_loss))
|
|
self.best_results['DLG_loss'][round][
|
|
sender] = self.reconstructor.dlg_recover_loss
|
|
if self.reconstructed_data_sav_fn is not None:
|
|
self.reconstructed_data_sav_fn(
|
|
data=self.reconstruct_data[message.state][
|
|
message.sender][0],
|
|
sav_pth=self._cfg.outdir,
|
|
name='image_state_{}_client_{}.png'.format(
|
|
message.state, message.sender))
|
|
|
|
self.check_and_move_on()
|
|
|
|
|
|
class PassivePIAServer(Server):
|
|
'''
|
|
The implementation of the batch property classifier, the algorithm 3 in
|
|
paper: Exploiting Unintended Feature Leakage in Collaborative Learning
|
|
|
|
References:
|
|
|
|
Melis, Luca, Congzheng Song, Emiliano De Cristofaro and Vitaly
|
|
Shmatikov. “Exploiting Unintended Feature Leakage in Collaborative
|
|
Learning.” 2019 IEEE Symposium on Security and Privacy (SP) (2019): 691-706
|
|
'''
|
|
def __init__(self,
|
|
ID=-1,
|
|
state=0,
|
|
data=None,
|
|
model=None,
|
|
client_num=5,
|
|
total_round_num=10,
|
|
device='cpu',
|
|
strategy=None,
|
|
**kwargs):
|
|
super(PassivePIAServer, self).__init__(ID=ID,
|
|
state=state,
|
|
data=data,
|
|
model=model,
|
|
client_num=client_num,
|
|
total_round_num=total_round_num,
|
|
device=device,
|
|
strategy=strategy,
|
|
**kwargs)
|
|
|
|
# self.offline_reconstruct = offline_reconstruct
|
|
self.atk_method = self._cfg.attack.attack_method
|
|
self.pia_attacker = PassivePropertyInference(
|
|
classier=self._cfg.attack.classifier_PIA,
|
|
fl_model_criterion=get_criterion(self._cfg.criterion.type,
|
|
device=self.device),
|
|
device=self.device,
|
|
grad_clip=self._cfg.grad.grad_clip,
|
|
dataset_name=self._cfg.data.type,
|
|
fl_local_update_num=self._cfg.train.local_update_steps,
|
|
fl_type_optimizer=self._cfg.train.optimizer.type,
|
|
fl_lr=self._cfg.train.optimizer.lr,
|
|
batch_size=100)
|
|
|
|
# self.optimizer = get_optimizer(
|
|
# type=self._cfg.fedopt.type_optimizer, model=self.model,
|
|
# lr=self._cfg.fedopt.optimizer.lr)
|
|
# print(self.optimizer)
|
|
def callback_funcs_model_para(self, message: Message):
|
|
if self.is_finish:
|
|
return 'finish'
|
|
|
|
round, sender, content = message.state, message.sender, message.content
|
|
self.sampler.change_state(sender, 'idle')
|
|
if round not in self.msg_buffer['train']:
|
|
self.msg_buffer['train'][round] = dict()
|
|
self.msg_buffer['train'][round][sender] = content
|
|
|
|
# collect the updates
|
|
self.pia_attacker.collect_updates(
|
|
previous_para=self.model.state_dict(),
|
|
updated_parameter=content[1],
|
|
round=round,
|
|
client_id=sender)
|
|
self.pia_attacker.get_data_for_dataset_prop_classifier(
|
|
model=self.model)
|
|
|
|
if self._cfg.federate.online_aggr:
|
|
# TODO: put this line to `check_and_move_on`
|
|
# currently, no way to know the latest `sender`
|
|
self.aggregator.inc(content)
|
|
self.check_and_move_on()
|
|
|
|
if self.state == self.total_round_num:
|
|
self.pia_attacker.train_property_classifier()
|
|
self.pia_results = self.pia_attacker.infer_collected()
|
|
print(self.pia_results)
|