import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import logging
import os
import numpy as np

import federatedscope.register as register

logger = logging.getLogger(__name__)


def label_to_onehot(target, num_classes=100):
    return torch.nn.functional.one_hot(target, num_classes)


def cross_entropy_for_onehot(pred, target):
    return torch.mean(torch.sum(-target * F.log_softmax(pred, dim=-1), 1))
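

# A sanity-check sketch (illustrative; the helper below is not part of the
# original module): on one-hot targets, cross_entropy_for_onehot matches
# torch's built-in cross entropy computed on class indices.
def _example_onehot_ce():
    pred = torch.randn(4, 10)
    target = torch.randint(0, 10, (4, ))
    manual = cross_entropy_for_onehot(pred, label_to_onehot(target, 10))
    builtin = F.cross_entropy(pred, target)
    assert torch.allclose(manual, builtin, atol=1e-6)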


def iDLG_trick(original_gradient, num_class, is_one_hot_label=False):
    '''
    Use the iDLG trick to recover the label. Paper: "iDLG: Improved Deep
    Leakage from Gradients", link: https://arxiv.org/abs/2001.02610

    Args:
        original_gradient: the gradient of the FL model; Type: list
        num_class: the total number of classes in the data
        is_one_hot_label: whether the dataset's labels are one-hot encoded;
            Type: bool

    Returns:
        The label recovered by the iDLG trick.
    '''
    # The row of the last layer's weight gradient whose entries sum to the
    # minimum (negative) value corresponds to the ground-truth label.
    last_weight_min = torch.argmin(torch.sum(original_gradient[-2], dim=-1),
                                   dim=-1).detach()

    if is_one_hot_label:
        label = label_to_onehot(
            last_weight_min.reshape((1, )).requires_grad_(False), num_class)
    else:
        label = last_weight_min
    return label
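

# A minimal sketch (illustrative; everything named here is made up) of the
# iDLG trick on a bare linear classification head. Note that the trick reads
# the gradient at index -2, i.e. it assumes the model's last two parameters
# are the classifier's weight and bias; the sign argument also assumes the
# inputs to that layer are non-negative (e.g. post-ReLU features).
def _example_idlg_label_recovery():
    head = torch.nn.Linear(8, 5)  # 8 features, 5 classes
    x = torch.rand(1, 8)  # non-negative features
    y = torch.tensor([3])
    loss = F.cross_entropy(head(x), y)
    # gradients ordered like head.parameters(): [weight, bias]
    grads = torch.autograd.grad(loss, list(head.parameters()))
    assert iDLG_trick(grads, num_class=5).item() == 3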


def cos_sim(input_gradient, gt_gradient):
    # cosine distance between the flattened gradients: 1 - cos(g1, g2)
    total = 1 - torch.nn.functional.cosine_similarity(
        input_gradient.flatten(), gt_gradient.flatten(), 0, 1e-10)
    return total


def total_variation(x):
    """Anisotropic total variation of an image batch (NCHW), normalized by
    the number of elements."""
    dx = torch.mean(torch.abs(x[:, :, :, :-1] - x[:, :, :, 1:]))
    dy = torch.mean(torch.abs(x[:, :, :-1, :] - x[:, :, 1:, :]))

    total = x.numel()
    return (dx + dy) / total
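

# Illustrative check (this helper is not in the original module): TV serves
# as a smoothness prior when reconstructing images, since natural images
# have low TV while noise has high TV; a constant image has TV exactly 0.
def _example_total_variation():
    noise = torch.rand(1, 1, 28, 28)
    flat = torch.ones(1, 1, 28, 28)
    assert total_variation(flat) == 0
    assert total_variation(noise) > total_variation(flat)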


def approximate_func(x, device, C1=20, C2=0.5):
    '''
    Differentiable approximation of the step function
    f(x) = 0 if x < 0.5 else 1, via a steep sigmoid.

    Args:
        x: input data
        device: the device to compute on
        C1: steepness of the sigmoid; the larger C1, the closer to a step
        C2: the threshold (location of the step)

    Returns:
        1/(1+e^{-1*C1 (x-C2)})
    '''
    C1 = torch.tensor(C1).to(torch.device(device))
    C2 = torch.tensor(C2).to(torch.device(device))

    return 1 / (1 + torch.exp(-1 * C1 * (x - C2)))
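

# A quick sanity check (illustrative): with the default steepness C1=20 the
# sigmoid saturates to the 0/1 step away from the threshold C2=0.5.
def _example_approximate_func():
    x = torch.tensor([0.0, 0.5, 1.0])
    y = approximate_func(x, device='cpu')
    assert y[0] < 0.01 and abs(y[1] - 0.5) < 1e-6 and y[2] > 0.99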


def get_classifier(classifier: str, model=None):
    if model is not None:
        return model

    if classifier == 'lr':
        from sklearn.linear_model import LogisticRegression
        model = LogisticRegression(random_state=0)
        return model
    elif classifier.lower() == 'randomforest':
        from sklearn.ensemble import RandomForestClassifier
        model = RandomForestClassifier(random_state=0)
        return model
    elif classifier.lower() == 'svm':
        from sklearn.svm import SVC
        from sklearn.preprocessing import StandardScaler
        from sklearn.pipeline import make_pipeline
        model = make_pipeline(StandardScaler(), SVC(gamma='auto'))
        return model
    else:
        raise ValueError(
            'classifier: {} is not supported'.format(classifier))
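

# Usage sketch (illustrative): fetch a scikit-learn classifier by name; a
# pre-built model passed via the `model` argument is returned unchanged.
def _example_get_classifier():
    clf = get_classifier('randomforest')
    clf.fit([[0.0, 1.0], [1.0, 0.0]], [0, 1])
    assert get_classifier('lr', model=clf) is clf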


def get_data_info(dataset_name):
    '''
    Get the dataset information: the feature dimension, the number of
    classes, and whether the labels are one-hot encoded.

    Args:
        dataset_name: dataset name; Type: str

    :returns:
        data_feature_dim, num_class, is_one_hot_label
    '''
    if dataset_name.lower() == 'femnist':
        return [1, 28, 28], 36, False
    else:
        raise ValueError(
            'Please provide the data info of {}: data_feature_dim, num_class'.
            format(dataset_name))


def get_data_sav_fn(dataset_name):
    if dataset_name.lower() == 'femnist':
        return sav_femnist_image
    else:
        logger.info(f"Reconstructed data saving function is not provided for "
                    f"dataset: {dataset_name}")
        return None


def sav_femnist_image(data, sav_pth, name):
    _ = plt.figure(figsize=(4, 4))

    if len(data.shape) == 2:
        # a single image: add batch and channel dimensions
        data = torch.unsqueeze(data, 0)
        data = torch.unsqueeze(data, 0)

    ind = min(data.shape[0], 16)

    # plot up to 16 reconstructed images on a 4x4 grid
    for i in range(ind):
        plt.subplot(4, 4, i + 1)
        plt.imshow(data[i, 0, :, :] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig(os.path.join(sav_pth, name))
    plt.close()


def get_info_diff_loss(info_diff_type):
    if info_diff_type.lower() == 'l2':
        info_diff_loss = torch.nn.MSELoss(reduction='sum')
    elif info_diff_type.lower() == 'l1':
        info_diff_loss = torch.nn.SmoothL1Loss(reduction='sum', beta=1e-5)
    elif info_diff_type.lower() == 'sim':
        info_diff_loss = cos_sim
    else:
        raise ValueError(
            'info_diff_type: {} is not supported'.format(info_diff_type))
    return info_diff_loss
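

# Usage sketch (illustrative): each variant measures the mismatch between a
# candidate gradient and the ground-truth gradient during reconstruction.
def _example_info_diff_loss():
    g1, g2 = torch.randn(32), torch.randn(32)
    for diff_type in ('l2', 'l1', 'sim'):
        loss = get_info_diff_loss(diff_type)(g1, g2)
        assert loss.dim() == 0  # each variant returns a scalar tensor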


def get_reconstructor(atk_method, **kwargs):
    '''
    Get the reconstructor for the given attack method.

    Args:
        atk_method: the attack method name; currently supported: "dlg"
            (DLG: Deep Leakage from Gradients) and "ig" (InvertGradient:
            Inverting Gradients); Type: str
        **kwargs: other arguments

    Returns:
        The reconstructor instance.
    '''
    if atk_method.lower() == 'dlg':
        from federatedscope.attack.privacy_attacks.reconstruction_opt import\
            DLG
        logger.info(
            '--------- Getting reconstructor: DLG --------------------')
        return DLG(max_ite=kwargs['max_ite'],
                   lr=kwargs['lr'],
                   federate_loss_fn=kwargs['federate_loss_fn'],
                   device=kwargs['device'],
                   federate_lr=kwargs['federate_lr'],
                   optim=kwargs['optim'],
                   info_diff_type=kwargs['info_diff_type'],
                   federate_method=kwargs['federate_method'])
    elif atk_method.lower() == 'ig':
        from federatedscope.attack.privacy_attacks.reconstruction_opt import\
            InvertGradient
        logger.info(
            '------- Getting reconstructor: InvertGradient ------------------')
        return InvertGradient(max_ite=kwargs['max_ite'],
                              lr=kwargs['lr'],
                              federate_loss_fn=kwargs['federate_loss_fn'],
                              device=kwargs['device'],
                              federate_lr=kwargs['federate_lr'],
                              optim=kwargs['optim'],
                              info_diff_type=kwargs['info_diff_type'],
                              federate_method=kwargs['federate_method'],
                              alpha_TV=kwargs['alpha_TV'])
    else:
        raise ValueError(
            "attack method: {} lacks reconstructor implementation".format(
                atk_method))
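

# Call sketch (the argument values below are hypothetical; only the keyword
# names mirror the kwargs consumed above):
#
#     reconstructor = get_reconstructor(
#         'dlg', max_ite=400, lr=0.1,
#         federate_loss_fn=torch.nn.CrossEntropyLoss(),
#         device='cpu', federate_lr=0.01, optim='Adam',
#         info_diff_type='l2', federate_method='FedAvg')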


def get_generator(dataset_name):
    '''
    Get the dataset's corresponding generator.

    Args:
        dataset_name: The dataset name; Type: str

    :returns:
        The generator; Type: object
    '''
    if dataset_name == 'femnist':
        from federatedscope.attack.models.gan_based_model import \
            GeneratorFemnist
        return GeneratorFemnist
    else:
        raise ValueError(
            "The generator to generate data like {} is not defined!".format(
                dataset_name))


def get_data_property(ctx):
    # A SHOWCASE for the Femnist dataset: Property := whether the digit
    # contains a circle (labels 0, 6 and 8).
    x, label = [_.to(ctx.device) for _ in ctx.data_batch]

    prop = torch.zeros(label.size(0))
    positive_labels = [0, 6, 8]
    for ind in range(label.size(0)):
        if label[ind] in positive_labels:
            prop[ind] = 1
    prop = prop.to(ctx.device)
    return prop


def get_passive_PIA_auxiliary_dataset(dataset_name):
    '''
    Get the auxiliary dataset for the passive property inference attack.

    Args:
        dataset_name (str): dataset name

    :returns:
        the auxiliary dataset for property inference attack. Type: dict

        {
            'x': array,
            'y': array,
            'prop': array
        }
    '''
    for func in register.auxiliary_data_loader_PIA_dict.values():
        criterion = func(dataset_name)
        if criterion is not None:
            return criterion
    if dataset_name == 'toy':

        def _generate_data(instance_num=1000, feature_num=5, save_data=False):
            """
            Generate synthetic linear-regression data in the Runner format.

            Args:
                instance_num: the number of instances
                feature_num: the feature dimension
                save_data: whether to save the generated data (unused)

            Returns:
                {
                    'x': ...,
                    'y': ...,
                    'prop': ...
                }
            """
            weights = np.random.normal(loc=0.0, scale=1.0, size=feature_num)
            bias = np.random.normal(loc=0.0, scale=1.0)

            prop_weights = np.random.normal(loc=0.0,
                                            scale=1.0,
                                            size=feature_num)
            prop_bias = np.random.normal(loc=0.0, scale=1.0)

            x = np.random.normal(loc=0.0,
                                 scale=0.5,
                                 size=(instance_num, feature_num))
            y = np.sum(x * weights, axis=-1) + bias
            y = np.expand_dims(y, -1)
            # the binary property: a random linear function of x passed
            # through a sigmoid and thresholded at 0.5
            prop = np.sum(x * prop_weights, axis=-1) + prop_bias
            prop = 1.0 * ((1 / (1 + np.exp(-1 * prop))) > 0.5)
            prop = np.expand_dims(prop, -1)

            data_train = {'x': x, 'y': y, 'prop': prop}
            return data_train

        return _generate_data()
    else:
        raise ValueError(
            'The data cannot be loaded. Please specify the data load function.'
        )
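

# Usage sketch (illustrative): the built-in 'toy' loader returns a synthetic
# regression set whose binary property is a random linear threshold of x.
def _example_toy_pia_data():
    data = get_passive_PIA_auxiliary_dataset('toy')
    assert data['x'].shape == (1000, 5)
    assert set(np.unique(data['prop'])).issubset({0.0, 1.0})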


def plot_mia_loss_compare(loss_in_pth, loss_out_pth, in_round=20):
    loss_in = np.loadtxt(loss_in_pth, delimiter=',')
    loss_out = np.loadtxt(loss_out_pth, delimiter=',')

    loss_in_all = []
    loss_out_all = []
    for i in range(len(loss_in)):
        # skip the entry at index in_round
        if i != in_round:
            loss_in_all.append(loss_in[i])
            loss_out_all.append(loss_out[i])

    plt.plot(loss_out_all, label='not-in', alpha=0.9, color='red', linewidth=2)
    plt.plot(loss_in_all,
             linestyle=':',
             label='in',
             alpha=0.9,
             linewidth=2,
             color='blue')

    plt.legend()
    plt.xlabel('Round', fontsize=16)
    plt.ylabel('$L_x$', fontsize=16)
    plt.show()