# FS-TFP/federatedscope/attack/auxiliary/poisoning_data.py

import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from federatedscope.core.auxiliaries.transform_builder import get_transform
from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger
from torch.utils.data import DataLoader
from federatedscope.attack.auxiliary.backdoor_utils import normalize
from federatedscope.core.trainers.enums import MODE
import pickle
import logging
import os
# Module-level logger named after this module; used for poisoning progress.
logger = logging.getLogger(__name__)
def load_poisoned_dataset_edgeset(data, ctx, mode):
    """Inject edge-case backdoor samples into one split of a client's data.

    Edge-case samples are read from pre-generated files under
    ``ctx.attack.edge_path``:

    * FEMNIST: ``poisoned_edgeset_fraction_0.1`` (train) and
      ``ardis_test_dataset.pt`` (test/val), both read with ``torch.load``;
      samples keep the labels stored in the file.
    * CIFAR-10: ``southwest_images_new_train.pkl`` (train) and
      ``southwest_images_new_test.pkl`` (test/val), both read with
      ``pickle.load``; every sample is labelled
      ``int(ctx.attack.target_label_ind)``. For training, only
      ``ctx.attack.edge_num`` randomly chosen images are used.

    For ``MODE.TRAIN`` the transformed samples are appended in place to
    ``data[mode].dataset``. For ``MODE.TEST`` / ``MODE.VAL`` they are wrapped
    in a separate, unshuffled ``DataLoader`` stored under
    ``data['poison_' + mode]``.

    Args:
        data: per-client dict keyed by mode; ``data[MODE.TRAIN].dataset`` must
            support ``append`` (presumably a list-backed DataLoader).
        ctx: global config; reads ``ctx.attack``, ``ctx.data.type`` and
            ``ctx.dataloader``.
        mode: one of ``MODE.TRAIN``, ``MODE.TEST``, ``MODE.VAL``.

    Returns:
        The same ``data`` dict, updated.

    Raises:
        RuntimeError: if ``ctx.data.type`` is neither FEMNIST nor CIFAR-10.
    """
    # get_transform returns an indexable collection of transform dicts; take
    # the first one, exactly as load_poisoned_dataset_pixel does.
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']
    load_path = ctx.attack.edge_path
    # NOTE: torch.load and pickle.load both unpickle their input and can run
    # arbitrary code -- edge_path must only point at trusted files.
    if "femnist" in ctx.data.type:
        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      "poisoned_edgeset_fraction_0.1")
            with open(train_path, "rb") as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)
            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                # Stored channel-first; transforms expect channel-last.
                sample = sample.numpy().transpose(1, 2, 0)
                data[mode].dataset.append((transforms_funcs(sample), label))
        if mode == MODE.TEST or mode == MODE.VAL:
            poison_testset = list()
            test_path = os.path.join(load_path, 'ardis_test_dataset.pt')
            # torch.load needs a binary stream; text mode fails to read it.
            with open(test_path, 'rb') as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)
            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                # Stored channel-first; transforms expect channel-last.
                sample = sample.numpy().transpose(1, 2, 0)
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = DataLoader(
                poison_testset,
                batch_size=ctx.dataloader.batch_size,
                shuffle=False,
                num_workers=ctx.dataloader.num_workers)
    elif "CIFAR10" in ctx.data.type:
        target_label = int(ctx.attack.target_label_ind)
        # Every CIFAR-10 edge-case sample is relabelled to the target class.
        label = target_label
        num_poisoned = ctx.attack.edge_num
        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      'southwest_images_new_train.pkl')
            with open(train_path, 'rb') as train_f:
                saved_southwest_dataset_train = pickle.load(train_f)
            # Randomly keep ``edge_num`` of the stored images (no repeats).
            sampled_indices = np.random.choice(
                saved_southwest_dataset_train.shape[0],
                num_poisoned,
                replace=False)
            saved_southwest_dataset_train = saved_southwest_dataset_train[
                sampled_indices, :, :, :]
            for ii in range(num_poisoned):
                sample = saved_southwest_dataset_train[ii]
                data[mode].dataset.append((transforms_funcs(sample), label))
            logger.info('adding {:d} edge-cased samples in CIFAR-10'.format(
                num_poisoned))
        if mode == MODE.TEST or mode == MODE.VAL:
            poison_testset = list()
            test_path = os.path.join(load_path,
                                     'southwest_images_new_test.pkl')
            with open(test_path, 'rb') as test_f:
                saved_southwest_dataset_test = pickle.load(test_f)
            for ii in range(len(saved_southwest_dataset_test)):
                sample = saved_southwest_dataset_test[ii]
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = DataLoader(
                poison_testset,
                batch_size=ctx.dataloader.batch_size,
                shuffle=False,
                num_workers=ctx.dataloader.num_workers)
    else:
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')
    return data


def _to_uint8_hwc(sample):
    """Convert a channel-first sample into a channel-last uint8 image.

    The sample is multiplied by 255, so it is presumably scaled to [0, 1].
    Values are clipped to [0, 255] *before* the uint8 cast so that
    out-of-range pixels saturate instead of wrapping around.
    """
    img = np.array(sample).transpose(1, 2, 0) * 255.0
    return np.clip(img, 0, 255).astype('uint8')


def addTrigger(dataset,
               target_label,
               inject_portion,
               mode,
               distance,
               trig_h,
               trig_w,
               trigger_type,
               label_type,
               surrogate_model=None,
               load_path=None):
    """Build a backdoored copy of ``dataset`` as a list of (image, label).

    Every kept sample is converted to a channel-last uint8 image. A random
    ``inject_portion`` fraction of indices gets ``trigger_type`` stamped on via
    ``selectTrigger`` and is relabelled to ``target_label`` (all-to-one,
    "dirty" label attack). For WaNet triggers in training mode, a further
    ``2 * inject_portion`` fraction gets the ``'wanetTriggerCross'`` pattern
    and keeps its original label.

    In ``MODE.TEST`` / ``MODE.VAL``, samples whose label already equals
    ``target_label`` are dropped.

    Args:
        dataset: indexable dataset of ``(sample, label)`` items; samples are
            channel-first arrays/tensors.
        target_label: label assigned to triggered samples.
        inject_portion: fraction of samples to trigger.
        mode: ``MODE.TRAIN``, ``MODE.TEST`` or ``MODE.VAL``.
        distance: forwarded to ``selectTrigger``.
        trig_h, trig_w: trigger size as a fraction of image height/width.
        trigger_type: trigger name forwarded to ``selectTrigger``.
        label_type: only ``'dirty'`` is implemented; anything else yields an
            empty list.
        surrogate_model: unused; kept for interface compatibility.
        load_path: forwarded to ``selectTrigger``.

    Returns:
        A new list of ``(uint8 HWC image, label)`` tuples.
    """
    num_samples = len(dataset)
    if num_samples == 0:
        return []
    height = dataset[0][0].shape[-2]
    width = dataset[0][0].shape[-1]
    trig_h = int(trig_h * height)
    trig_w = int(trig_w * width)

    num_inject = int(num_samples * inject_portion)
    # Sets give O(1) membership tests; ``i in ndarray`` is a linear scan.
    cross_idx = set()
    if 'wanet' in trigger_type:
        cross_portion = 2  # default val following the original paper
        num_total = int(num_samples * inject_portion * (1 + cross_portion))
        perm_then = np.random.permutation(num_samples)[:num_total]
        inject_idx = set(perm_then[:num_inject].tolist())
        # Cross samples are all remaining drawn indices; starting at
        # ``num_inject + 1`` would silently skip one of them.
        cross_idx = set(perm_then[num_inject:].tolist())
    else:
        inject_idx = set(
            np.random.permutation(num_samples)[:num_inject].tolist())

    if label_type != 'dirty':
        # 'clean_label' (and anything else) is not implemented.
        logger.warning('addTrigger: label_type {} is not supported; '
                       'returning an empty dataset'.format(label_type))
        return []

    is_train = mode == MODE.TRAIN
    is_eval = mode == MODE.TEST or mode == MODE.VAL
    if not (is_train or is_eval):
        return []

    poisoned = []
    for i in range(num_samples):
        item = dataset[i]
        label = item[1]
        if is_eval and label == target_label:
            # Evaluation drops samples that already carry the target label.
            continue
        img = _to_uint8_hwc(item[0])
        img_h, img_w = img.shape[0], img.shape[1]
        # NOTE(review): every call now uses the (height, width, ..., trig_h,
        # trig_w) order of the original training call; the test/cross calls
        # previously swapped them -- confirm against selectTrigger.
        if i in inject_idx:
            img = selectTrigger(img, img_h, img_w, distance, trig_h, trig_w,
                                trigger_type, load_path)
            label = target_label
        elif is_train and i in cross_idx:
            img = selectTrigger(img, img_h, img_w, distance, trig_h, trig_w,
                                'wanetTriggerCross', load_path)
        poisoned.append((img, label))
    return poisoned


def load_poisoned_dataset_pixel(data, ctx, mode):
    """Stamp pixel-pattern backdoor triggers onto one split of a client.

    ``data[mode].dataset`` is passed through ``addTrigger`` (``distance=1``,
    trigger height/width fractions of 0.1), and each resulting sample is then
    run through the ``'transform'`` entry of the first dict returned by
    ``get_transform(ctx, 'torchvision')``.

    * ``MODE.TRAIN``: a ``ctx.attack.poison_ratio`` fraction is triggered and
      the result replaces ``data[mode]`` as a shuffled ``DataLoader``.
    * ``MODE.TEST`` / ``MODE.VAL``: the inject portion is 1.0 and the result
      is stored, unshuffled, under ``data['poison_' + mode]``.

    Any other ``mode`` leaves ``data`` untouched.

    Raises:
        RuntimeError: when ``ctx.data.type`` is neither FEMNIST nor CIFAR-10.
    """
    trigger_type = ctx.attack.trigger_type
    label_type = ctx.attack.label_type
    target_label = int(ctx.attack.target_label_ind)
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']

    dataset_type = ctx.data.type
    if not ("femnist" in dataset_type or "CIFAR10" in dataset_type):
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')
    inject_portion_train = ctx.attack.poison_ratio
    inject_portion_test = 1.0
    load_path = ctx.attack.trigger_path

    # Per-mode settings: (fraction to poison, destination key, shuffle?).
    if mode == MODE.TRAIN:
        portion, out_key, shuffle = inject_portion_train, mode, True
    elif mode == MODE.TEST or mode == MODE.VAL:
        portion, out_key, shuffle = (inject_portion_test, 'poison_' + mode,
                                     False)
    else:
        return data

    triggered = addTrigger(data[mode].dataset,
                           target_label,
                           portion,
                           mode=mode,
                           distance=1,
                           trig_h=0.1,
                           trig_w=0.1,
                           trigger_type=trigger_type,
                           label_type=label_type,
                           load_path=load_path)
    transformed = [(transforms_funcs(sample), label)
                   for sample, label in triggered]
    data[out_key] = DataLoader(transformed,
                               batch_size=ctx.dataloader.batch_size,
                               shuffle=shuffle,
                               num_workers=ctx.dataloader.num_workers)
    return data


def add_trans_normalize(data, ctx):
    '''
    Normalize every split of one client's data in place.

    data for each client is a dictionary whose values expose a mutable,
    indexable ``.dataset`` of ``(sample, label)`` items.

    For CIFAR-10 training data, each sample is first turned back into a uint8
    PIL image and re-augmented with ``RandomHorizontalFlip`` + ``ToTensor``.
    Every split is then normalized with ``normalize(sample,
    ctx.attack.mean, ctx.attack.std)``.

    Returns the same ``data`` dict.
    '''
    # Loop-invariant: read the statistics and build the augmentation once.
    mean, std = ctx.attack.mean, ctx.attack.std
    augment_train = "CIFAR10" in ctx.data.type
    tran_train = transforms.Compose(
        [transforms.RandomHorizontalFlip(),
         transforms.ToTensor()])
    for key in data:
        dataset = data[key].dataset
        num_dataset = len(dataset)
        if augment_train and key == MODE.TRAIN:
            for iii in range(num_dataset):
                item = dataset[iii]
                sample = np.array(item[0]).transpose(1, 2, 0) * 255.0
                # Clip before the uint8 cast; casting first wraps values
                # outside [0, 255] and makes the clip a no-op.
                sample = np.clip(sample, 0, 255).astype('uint8')
                sample = tran_train(Image.fromarray(sample))
                dataset[iii] = (normalize(sample, mean, std), item[1])
        else:
            for iii in range(num_dataset):
                item = dataset[iii]
                dataset[iii] = (normalize(item[0], mean, std), item[1])
    return data


def select_poisoning(data, ctx, mode):
    """Dispatch one split of a client's data to the matching poisoning routine.

    The choice is based on substrings of ``ctx.attack.trigger_type``, checked
    in this order:

    * contains ``'edge'``     -> ``load_poisoned_dataset_edgeset``
    * contains ``'semantic'`` -> data is returned unchanged
    * otherwise               -> ``load_poisoned_dataset_pixel``
    """
    if 'edge' in ctx.attack.trigger_type:
        return load_poisoned_dataset_edgeset(data, ctx, mode)
    if 'semantic' in ctx.attack.trigger_type:
        # No data manipulation is done here for semantic triggers.
        return data
    return load_poisoned_dataset_pixel(data, ctx, mode)


def poisoning(data, ctx):
    """Poison the attacker client's data in place.

    Only the client whose id equals ``ctx.attack.attacker_id`` is modified,
    and only ids in ``1 .. len(data)`` are eligible; any other id is a no-op.
    The train and test splits, plus the validation split when present and
    truthy, go through ``select_poisoning``. All of that client's splits are
    then normalized by ``add_trans_normalize``.

    Args:
        data: dict mapping client id to that client's dict of splits.
        ctx: global config; reads ``ctx.attack``.

    Returns:
        None. ``data`` is updated in place.
    """
    attacker_id = ctx.attack.attacker_id
    if attacker_id not in range(1, len(data) + 1):
        return
    # Normalize to a plain int, matching the range index previously used.
    i = int(attacker_id)
    logger.info(50 * '-')
    logger.info('start poisoning at Client: {}'.format(i))
    logger.info(50 * '-')
    data[i] = select_poisoning(data[i], ctx, mode=MODE.TRAIN)
    data[i] = select_poisoning(data[i], ctx, mode=MODE.TEST)
    if data[i].get(MODE.VAL):
        data[i] = select_poisoning(data[i], ctx, mode=MODE.VAL)
    data[i] = add_trans_normalize(data[i], ctx)
    # Adjacent literals are joined at compile time. A backslash continuation
    # inside the string would embed the next line's indentation in the message.
    logger.info('finishing the clean and {} poisoning data processing '
                'for Client {:d}'.format(ctx.attack.trigger_type, i))