296 lines
12 KiB
Python
296 lines
12 KiB
Python
import torch
|
|
from PIL import Image
|
|
import numpy as np
|
|
from torchvision import transforms
|
|
from federatedscope.core.auxiliaries.transform_builder import get_transform
|
|
from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger
|
|
from torch.utils.data import DataLoader
|
|
from federatedscope.attack.auxiliary.backdoor_utils import normalize
|
|
from federatedscope.core.trainers.enums import MODE
|
|
import pickle
|
|
import logging
|
|
import os
|
|
|
|
# Module-level logger used by the poisoning helpers in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _build_poison_loader(poison_set, ctx):
    """Wrap ``(sample, label)`` pairs in a non-shuffling ``DataLoader``.

    Batch size and worker count are read from ``ctx.dataloader``.
    """
    return DataLoader(poison_set,
                      batch_size=ctx.dataloader.batch_size,
                      shuffle=False,
                      num_workers=ctx.dataloader.num_workers)


def load_poisoned_dataset_edgeset(data, ctx, mode):
    """Inject edge-case backdoor samples into one split of a client's data.

    For ``MODE.TRAIN`` the poisoned samples are appended in place to
    ``data[mode].dataset``.  For ``MODE.TEST`` / ``MODE.VAL`` a separate
    loader holding only poisoned samples is stored under
    ``data['poison_' + mode]``.

    Args:
        data: per-client mapping from mode to a loader-like object that
            exposes a mutable ``dataset`` list.
        ctx: configuration; reads ``attack.edge_path``, ``data.type``,
            ``dataloader.batch_size`` / ``dataloader.num_workers`` and, for
            CIFAR-10, ``attack.target_label_ind`` and ``attack.edge_num``.
        mode: one of ``MODE.TRAIN``, ``MODE.TEST`` or ``MODE.VAL``.

    Returns:
        The same ``data`` mapping, updated in place.

    Raises:
        RuntimeError: if ``ctx.data.type`` names neither FEMNIST nor
            CIFAR-10.
    """
    # NOTE(review): indexed the same way as in load_poisoned_dataset_pixel,
    # i.e. assuming get_transform returns a sequence whose first element is
    # the training-transform dict -- confirm against get_transform.
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']
    load_path = ctx.attack.edge_path
    is_eval_mode = mode == MODE.TEST or mode == MODE.VAL

    if "femnist" in ctx.data.type:
        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      "poisoned_edgeset_fraction_0.1")
            with open(train_path, "rb") as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)

            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                # (c, h, w) tensor -> (h, w, c) array for the transform
                sample = sample.numpy().transpose(1, 2, 0)
                data[mode].dataset.append((transforms_funcs(sample), label))

        if is_eval_mode:
            poison_testset = []
            test_path = os.path.join(load_path, 'ardis_test_dataset.pt')
            # torch.load needs a binary stream; a text-mode handle cannot be
            # deserialized.
            with open(test_path, "rb") as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)

            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                # (c, h, w) tensor -> (h, w, c) array for the transform
                sample = sample.numpy().transpose(1, 2, 0)
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = _build_poison_loader(poison_testset, ctx)

    elif "CIFAR10" in ctx.data.type:
        # Every edge-case image is relabelled with the attack target.
        label = int(ctx.attack.target_label_ind)
        num_poisoned = ctx.attack.edge_num

        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      'southwest_images_new_train.pkl')
            with open(train_path, 'rb') as train_f:
                # SECURITY: pickle can execute arbitrary code while loading;
                # only point attack.edge_path at trusted files.
                saved_southwest_dataset_train = pickle.load(train_f)

            # Draw num_poisoned distinct images; np.random.choice raises
            # ValueError if fewer images are available.
            sampled_indices = np.random.choice(
                saved_southwest_dataset_train.shape[0],
                num_poisoned,
                replace=False)
            saved_southwest_dataset_train = saved_southwest_dataset_train[
                sampled_indices, :, :, :]

            for ii in range(num_poisoned):
                sample = saved_southwest_dataset_train[ii]
                data[mode].dataset.append((transforms_funcs(sample), label))

            logger.info('adding {:d} edge-cased samples in CIFAR-10'.format(
                num_poisoned))

        if is_eval_mode:
            poison_testset = []
            test_path = os.path.join(load_path,
                                     'southwest_images_new_test.pkl')
            with open(test_path, 'rb') as test_f:
                # SECURITY: see the note on pickle in the training branch.
                saved_southwest_dataset_test = pickle.load(test_f)

            for ii in range(len(saved_southwest_dataset_test)):
                sample = saved_southwest_dataset_test[ii]
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = _build_poison_loader(poison_testset, ctx)

    else:
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')

    return data
|
|
|
|
|
|
def _chw_to_uint8_hwc(sample):
    """Turn a channel-first image into a channel-last ``uint8`` array.

    The input is multiplied by 255 (presumably it is scaled to [0, 1] --
    confirm with the caller) and clipped to [0, 255] *before* the cast, so
    out-of-range values saturate instead of wrapping around.
    """
    img = np.array(sample).transpose(1, 2, 0) * 255.0
    return np.clip(img, 0, 255).astype('uint8')


def addTrigger(dataset,
               target_label,
               inject_portion,
               mode,
               distance,
               trig_h,
               trig_w,
               trigger_type,
               label_type,
               surrogate_model=None,
               load_path=None):
    """Build a copy of ``dataset`` with a backdoor trigger on some samples.

    Args:
        dataset: indexable collection of ``(image, label)`` pairs; the last
            two image axes are read as height and width.
        target_label: label assigned to every triggered sample.
        inject_portion: fraction of samples that receive the trigger.
        mode: ``MODE.TRAIN``, ``MODE.TEST`` or ``MODE.VAL``.  In TEST/VAL,
            samples already labelled ``target_label`` are dropped.
        distance: forwarded unchanged to ``selectTrigger``.
        trig_h: trigger height as a fraction of the image height.
        trig_w: trigger width as a fraction of the image width.
        trigger_type: trigger name forwarded to ``selectTrigger``; names
            containing ``'wanet'`` additionally get "cross" samples that are
            triggered but keep their original label.
        label_type: only ``'dirty'`` produces output; ``'clean_label'`` is
            a placeholder, and any other value likewise yields an empty
            list.
        surrogate_model: unused by this function.
        load_path: forwarded unchanged to ``selectTrigger``.

    Returns:
        A list of ``(uint8 H x W x C image, label)`` tuples.  An empty
        ``dataset`` yields an empty list.
    """
    num_samples = len(dataset)
    if num_samples == 0:
        # Nothing to poison; also avoids indexing dataset[0] below.
        return []

    height = dataset[0][0].shape[-2]
    width = dataset[0][0].shape[-1]
    trig_h = int(trig_h * height)
    trig_w = int(trig_w * width)

    is_wanet = 'wanet' in trigger_type
    num_poison = int(num_samples * inject_portion)
    cross_idx = set()
    if is_wanet:
        cross_portion = 2  # default val following the original paper
        num_selected = int(num_samples * inject_portion * (1 + cross_portion))
        perm_then = np.random.permutation(num_samples)[0:num_selected]
        poison_idx = set(perm_then[0:num_poison].tolist())
        # NOTE(review): the cross slice starts at num_poison + 1, so the
        # index at position num_poison is used by neither set -- kept as in
        # the original; confirm whether this is intentional.
        cross_idx = set(perm_then[num_poison + 1:num_selected].tolist())
    else:
        poison_idx = set(
            np.random.permutation(num_samples)[0:num_poison].tolist())

    dataset_ = []
    for i in range(num_samples):
        data = dataset[i]

        if label_type == 'dirty':
            # all2one attack
            if mode == MODE.TRAIN:
                img = _chw_to_uint8_hwc(data[0])
                height = img.shape[0]
                width = img.shape[1]

                if i in poison_idx:
                    img = selectTrigger(img, height, width, distance, trig_h,
                                        trig_w, trigger_type, load_path)
                    dataset_.append((img, target_label))
                elif is_wanet and i in cross_idx:
                    # NOTE(review): width/height and trig_w/trig_h are passed
                    # in the opposite order to the call above -- kept as in
                    # the original; confirm against selectTrigger.
                    img = selectTrigger(img, width, height, distance, trig_w,
                                        trig_h, 'wanetTriggerCross',
                                        load_path)
                    dataset_.append((img, data[1]))
                else:
                    dataset_.append((img, data[1]))

            if mode == MODE.TEST or mode == MODE.VAL:
                # Samples already in the target class are dropped.
                if data[1] == target_label:
                    continue

                img = _chw_to_uint8_hwc(data[0])
                height = img.shape[0]
                width = img.shape[1]
                if i in poison_idx:
                    # NOTE(review): same swapped argument order as the
                    # wanet-cross call above -- confirm.
                    img = selectTrigger(img, width, height, distance, trig_w,
                                        trig_h, trigger_type, load_path)
                    dataset_.append((img, target_label))
                else:
                    dataset_.append((img, data[1]))

        elif label_type == 'clean_label':
            # Not implemented: clean-label poisoning contributes no samples.
            pass

    return dataset_
|
|
|
|
|
|
def load_poisoned_dataset_pixel(data, ctx, mode):
    """Apply a pixel-level trigger to one split of a client's data.

    ``MODE.TRAIN`` replaces ``data[mode]`` with a shuffling loader over the
    partially poisoned training set (ratio ``ctx.attack.poison_ratio``).
    ``MODE.TEST`` / ``MODE.VAL`` leave ``data[mode]`` untouched and store a
    non-shuffling loader built with an injection portion of 1.0 under
    ``data['poison_' + mode]``.

    Args:
        data: per-client mapping from mode to an object with a ``dataset``
            attribute.
        ctx: configuration; reads ``attack.trigger_type``,
            ``attack.label_type``, ``attack.target_label_ind``,
            ``attack.poison_ratio``, ``attack.trigger_path``, ``data.type``
            and ``dataloader.batch_size`` / ``dataloader.num_workers``.
        mode: one of ``MODE.TRAIN``, ``MODE.TEST`` or ``MODE.VAL``.

    Returns:
        The same ``data`` mapping, updated in place.

    Raises:
        RuntimeError: if ``ctx.data.type`` names neither FEMNIST nor
            CIFAR-10.
    """
    trigger_type = ctx.attack.trigger_type
    label_type = ctx.attack.label_type
    target_label = int(ctx.attack.target_label_ind)
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']

    if not ("femnist" in ctx.data.type or "CIFAR10" in ctx.data.type):
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')
    train_portion = ctx.attack.poison_ratio
    eval_portion = 1.0

    load_path = ctx.attack.trigger_path

    def _poisoned_loader(portion, shuffle):
        # Trigger the current split, then run every image through the
        # dataset transform before wrapping it in a loader.
        triggered = addTrigger(data[mode].dataset,
                               target_label,
                               portion,
                               mode=mode,
                               distance=1,
                               trig_h=0.1,
                               trig_w=0.1,
                               trigger_type=trigger_type,
                               label_type=label_type,
                               load_path=load_path)
        for idx, (sample, label) in enumerate(triggered):
            triggered[idx] = (transforms_funcs(sample), label)
        return DataLoader(triggered,
                          batch_size=ctx.dataloader.batch_size,
                          shuffle=shuffle,
                          num_workers=ctx.dataloader.num_workers)

    if mode == MODE.TRAIN:
        data[mode] = _poisoned_loader(train_portion, True)

    if mode in (MODE.TEST, MODE.VAL):
        data['poison_' + mode] = _poisoned_loader(eval_portion, False)

    return data
|
|
|
|
|
|
def add_trans_normalize(data, ctx):
    '''
    Normalize every sample of every split of one client, in place.

    data for each client is a dictionary; each value must expose a mutable
    ``dataset`` sequence of ``(image, label, ...)`` entries.  Every entry is
    replaced by ``(normalize(image, mean, std), label)`` with ``mean`` and
    ``std`` taken from ``ctx.attack``.  For the CIFAR-10 training split the
    image is first converted to a PIL image and passed through a random
    horizontal flip followed by ``ToTensor``.

    Args:
        data: per-client dictionary keyed by split name.
        ctx: configuration; reads ``attack.mean``, ``attack.std`` and
            ``data.type``.

    Returns:
        The same ``data`` dictionary.
    '''
    for key in data:
        samples = data[key].dataset
        mean, std = ctx.attack.mean, ctx.attack.std

        if "CIFAR10" in ctx.data.type and key == MODE.TRAIN:
            tran_train = transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
            ])
            for iii in range(len(samples)):
                # channel-first -> channel-last, rescaled by 255 for PIL
                sample = np.array(samples[iii][0]).transpose(1, 2, 0) * 255.0
                # Clip before the cast: clipping an already-cast uint8 array
                # is a no-op and lets out-of-range values wrap around.
                sample = np.clip(sample, 0, 255).astype('uint8')
                sample = Image.fromarray(sample)
                sample = tran_train(sample)
                samples[iii] = (normalize(sample, mean, std),
                                samples[iii][1])
        else:
            for iii in range(len(samples)):
                samples[iii] = (normalize(samples[iii][0], mean, std),
                                samples[iii][1])

    return data
|
|
|
|
|
|
def select_poisoning(data, ctx, mode):
    """Dispatch to the poisoning routine matching the configured trigger.

    Trigger types containing ``'edge'`` use the edge-case loader, those
    containing ``'semantic'`` leave ``data`` unchanged, and everything else
    falls through to the pixel-trigger loader.

    Args:
        data: per-client data mapping.
        ctx: configuration; reads ``attack.trigger_type``.
        mode: split to poison, forwarded to the selected routine.

    Returns:
        The (possibly updated) ``data`` mapping.
    """
    trigger = ctx.attack.trigger_type
    if 'edge' in trigger:
        return load_poisoned_dataset_edgeset(data, ctx, mode)
    if 'semantic' in trigger:
        # No data-level change is made for semantic triggers here.
        return data
    return load_poisoned_dataset_pixel(data, ctx, mode)
|
|
|
|
|
|
def poisoning(data, ctx):
    """Poison and normalize the data of every client, in place.

    Clients are addressed as ``data[1] .. data[len(data)]``.  The client
    whose index equals ``ctx.attack.attacker_id`` gets its training split
    poisoned; every client gets poisoned TEST (and, when present, VAL)
    data, followed by normalization via ``add_trans_normalize``.

    Args:
        data: mapping from client index to that client's per-mode data
            mapping; entries are replaced in place.
        ctx: configuration; reads ``attack.attacker_id`` and
            ``attack.trigger_type`` and is forwarded to the helpers.

    Returns:
        None.
    """
    # NOTE(review): the nesting below (only the TRAIN poisoning is gated on
    # the attacker id) was reconstructed from the statement order; confirm
    # against the upstream file.
    for i in range(1, len(data) + 1):
        if i == ctx.attack.attacker_id:
            logger.info(50 * '-')
            logger.info('start poisoning at Client: {}'.format(i))
            logger.info(50 * '-')
            data[i] = select_poisoning(data[i], ctx, mode=MODE.TRAIN)
        data[i] = select_poisoning(data[i], ctx, mode=MODE.TEST)
        if data[i].get(MODE.VAL):
            data[i] = select_poisoning(data[i], ctx, mode=MODE.VAL)
        data[i] = add_trans_normalize(data[i], ctx)
        # Adjacent literals instead of a backslash continuation inside the
        # string, so source indentation cannot leak into the message.
        logger.info('finishing the clean and {} poisoning data processing '
                    'for Client {:d}'.format(ctx.attack.trigger_type, i))
|