import torch
from PIL import Image
import numpy as np
from torchvision import transforms
from federatedscope.core.auxiliaries.transform_builder import get_transform
from federatedscope.attack.auxiliary.backdoor_utils import selectTrigger
from torch.utils.data import DataLoader
from federatedscope.attack.auxiliary.backdoor_utils import normalize
from federatedscope.core.trainers.enums import MODE
import pickle
import logging
import os

logger = logging.getLogger(__name__)


def _to_uint8_hwc(sample):
    """Convert a channel-first sample into a channel-last ``uint8`` array.

    The sample is converted with ``np.array``, its axes are reordered with
    ``transpose(1, 2, 0)`` and it is scaled by 255. Values are clipped to
    ``[0, 255]`` *before* the cast so that out-of-range inputs saturate
    instead of wrapping around during the ``uint8`` conversion.
    """
    scaled = np.array(sample).transpose(1, 2, 0) * 255.0
    return np.clip(scaled, 0, 255).astype('uint8')


def _build_poison_loader(samples, ctx):
    """Wrap poisoned evaluation samples in a non-shuffling ``DataLoader``."""
    return DataLoader(samples,
                      batch_size=ctx.dataloader.batch_size,
                      shuffle=False,
                      num_workers=ctx.dataloader.num_workers)


def load_poisoned_dataset_edgeset(data, ctx, mode):
    """Add edge-case backdoor samples loaded from ``ctx.attack.edge_path``.

    For ``MODE.TRAIN`` the loaded samples are appended to
    ``data[mode].dataset``. For ``MODE.TEST`` / ``MODE.VAL`` a separate
    ``DataLoader`` holding only the edge-case samples is stored under the
    key ``'poison_' + mode``.

    Args:
        data: per-client dictionary mapping a mode to an object exposing a
            ``dataset`` attribute that supports ``append``.
        ctx: configuration; reads ``ctx.attack.edge_path``,
            ``ctx.attack.target_label_ind``, ``ctx.attack.edge_num``,
            ``ctx.data.type`` and ``ctx.dataloader``.
        mode: one of the ``MODE`` constants.

    Returns:
        The same ``data`` dictionary, updated in place.

    Raises:
        RuntimeError: if ``ctx.data.type`` contains neither ``"femnist"``
            nor ``"CIFAR10"``.
    """
    # Indexed the same way as in load_poisoned_dataset_pixel so that both
    # loaders agree on how the result of get_transform is consumed.
    # NOTE(review): confirm against get_transform's actual return value.
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']
    load_path = ctx.attack.edge_path
    is_eval = mode == MODE.TEST or mode == MODE.VAL

    if "femnist" in ctx.data.type:
        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      "poisoned_edgeset_fraction_0.1")
            with open(train_path, "rb") as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)

            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                # reorder the axes so the last one comes first -> last
                sample = sample.numpy().transpose(1, 2, 0)
                data[mode].dataset.append((transforms_funcs(sample), label))

        if is_eval:
            poison_testset = []
            test_path = os.path.join(load_path, 'ardis_test_dataset.pt')
            # torch.load needs a binary file object; text mode fails here.
            with open(test_path, "rb") as saved_data_file:
                poisoned_edgeset = torch.load(saved_data_file)

            for ii in range(len(poisoned_edgeset)):
                sample, label = poisoned_edgeset[ii]
                sample = sample.numpy().transpose(1, 2, 0)
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = _build_poison_loader(poison_testset, ctx)

    elif "CIFAR10" in ctx.data.type:
        # Every edge-case sample is labelled with the attack target label.
        label = int(ctx.attack.target_label_ind)
        num_poisoned = ctx.attack.edge_num

        if mode == MODE.TRAIN:
            train_path = os.path.join(load_path,
                                      'southwest_images_new_train.pkl')
            # NOTE: pickle.load executes arbitrary code from the file; only
            # point ctx.attack.edge_path at trusted data.
            with open(train_path, 'rb') as train_f:
                saved_southwest_dataset_train = pickle.load(train_f)

            # Draw `num_poisoned` distinct samples at random.
            sampled_indices = np.random.choice(
                saved_southwest_dataset_train.shape[0],
                num_poisoned,
                replace=False)
            selected = saved_southwest_dataset_train[sampled_indices, :, :, :]

            for ii in range(num_poisoned):
                data[mode].dataset.append(
                    (transforms_funcs(selected[ii]), label))

            logger.info('adding {:d} edge-cased samples in CIFAR-10'.format(
                num_poisoned))

        if is_eval:
            poison_testset = []
            test_path = os.path.join(load_path,
                                     'southwest_images_new_test.pkl')
            # NOTE: see the pickle warning above; trusted files only.
            with open(test_path, 'rb') as test_f:
                saved_southwest_dataset_test = pickle.load(test_f)

            for ii in range(len(saved_southwest_dataset_test)):
                sample = saved_southwest_dataset_test[ii]
                poison_testset.append((transforms_funcs(sample), label))
            data['poison_' + mode] = _build_poison_loader(poison_testset, ctx)

    else:
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')

    return data


def addTrigger(dataset,
               target_label,
               inject_portion,
               mode,
               distance,
               trig_h,
               trig_w,
               trigger_type,
               label_type,
               surrogate_model=None,
               load_path=None):
    """Build a poisoned copy of ``dataset`` by stamping a trigger on samples.

    Only ``label_type == 'dirty'`` produces output; any other label type
    (including ``'clean_label'``) yields an empty list.

    Args:
        dataset: indexable collection of ``(sample, label)`` pairs; each
            sample exposes ``shape`` and is converted with ``np.array``.
        target_label: label assigned to triggered samples.
        inject_portion: fraction of the dataset that receives the trigger.
        mode: one of the ``MODE`` constants. In ``MODE.TEST`` / ``MODE.VAL``
            samples whose label already equals ``target_label`` are dropped.
        distance: forwarded unchanged to ``selectTrigger``.
        trig_h: trigger height as a fraction of the sample height.
        trig_w: trigger width as a fraction of the sample width.
        trigger_type: trigger name forwarded to ``selectTrigger``; when it
            contains ``'wanet'`` an additional "cross" subset is selected
            and processed with ``'wanetTriggerCross'`` during training.
        label_type: ``'dirty'`` or ``'clean_label'``.
        surrogate_model: unused.
        load_path: forwarded unchanged to ``selectTrigger``.

    Returns:
        A list of ``(image, label)`` tuples.
    """
    height = dataset[0][0].shape[-2]
    width = dataset[0][0].shape[-1]
    trig_h = int(trig_h * height)
    trig_w = int(trig_w * width)

    num_samples = len(dataset)
    num_poisoned = int(num_samples * inject_portion)

    cross_ids = set()
    if 'wanet' in trigger_type:
        cross_portion = 2  # default val following the original paper
        num_selected = int(num_samples * inject_portion * (1 + cross_portion))
        perm_then = np.random.permutation(num_samples)[0:num_selected]
        perm = perm_then[0:num_poisoned]
        # NOTE(review): the slice starts at num_poisoned + 1, so the entry at
        # position num_poisoned lands in neither subset. Kept as-is to leave
        # the existing sampling behaviour untouched.
        cross_ids = set(perm_then[(num_poisoned + 1):num_selected].tolist())
    else:
        perm = np.random.permutation(num_samples)[0:num_poisoned]
    # Sets give O(1) membership tests inside the per-sample loop below.
    poisoned_ids = set(perm.tolist())

    dataset_ = []
    if label_type != 'dirty':
        # 'clean_label' (and anything else) is not implemented: no samples.
        return dataset_

    is_train = mode == MODE.TRAIN
    is_eval = mode == MODE.TEST or mode == MODE.VAL

    # all2one attack
    # NOTE(review): the selectTrigger calls below do not all pass
    # height/width (and trig_h/trig_w) in the same order; the original
    # order of every call is preserved. Confirm against selectTrigger's
    # signature.
    for i in range(num_samples):
        data = dataset[i]

        if is_train:
            img = _to_uint8_hwc(data[0])
            height = img.shape[0]
            width = img.shape[1]
            if i in poisoned_ids:
                img = selectTrigger(img, height, width, distance, trig_h,
                                    trig_w, trigger_type, load_path)
                dataset_.append((img, target_label))
            elif i in cross_ids:
                # Cross samples are perturbed but keep their true label.
                img = selectTrigger(img, width, height, distance, trig_w,
                                    trig_h, 'wanetTriggerCross', load_path)
                dataset_.append((img, data[1]))
            else:
                dataset_.append((img, data[1]))

        elif is_eval:
            # Skip samples that already carry the target label.
            if data[1] == target_label:
                continue
            img = _to_uint8_hwc(data[0])
            height = img.shape[0]
            width = img.shape[1]
            if i in poisoned_ids:
                img = selectTrigger(img, width, height, distance, trig_w,
                                    trig_h, trigger_type, load_path)
                dataset_.append((img, target_label))
            else:
                dataset_.append((img, data[1]))

    return dataset_


def _poison_with_pixel_trigger(dataset, ctx, mode, inject_portion,
                               transforms_funcs):
    """Run ``addTrigger`` with the attack settings and transform the result."""
    poisoned = addTrigger(dataset,
                          int(ctx.attack.target_label_ind),
                          inject_portion,
                          mode=mode,
                          distance=1,
                          trig_h=0.1,
                          trig_w=0.1,
                          trigger_type=ctx.attack.trigger_type,
                          label_type=ctx.attack.label_type,
                          load_path=ctx.attack.trigger_path)
    return [(transforms_funcs(sample), label) for sample, label in poisoned]


def load_poisoned_dataset_pixel(data, ctx, mode):
    """Poison a client's data with a pixel-level trigger.

    For ``MODE.TRAIN`` the loader in ``data[mode]`` is replaced by a
    shuffling ``DataLoader`` over the poisoned samples, using
    ``ctx.attack.poison_ratio`` as the injection fraction. For
    ``MODE.TEST`` / ``MODE.VAL`` an injection fraction of 1.0 is used and
    the result is stored under ``'poison_' + mode`` without shuffling.

    Args:
        data: per-client dictionary mapping a mode to an object exposing a
            ``dataset`` attribute.
        ctx: configuration; reads ``ctx.attack``, ``ctx.data.type`` and
            ``ctx.dataloader``.
        mode: one of the ``MODE`` constants.

    Returns:
        The same ``data`` dictionary, updated in place.

    Raises:
        RuntimeError: if ``ctx.data.type`` contains neither ``"femnist"``
            nor ``"CIFAR10"``.
    """
    transforms_funcs = get_transform(ctx, 'torchvision')[0]['transform']

    if "femnist" in ctx.data.type or "CIFAR10" in ctx.data.type:
        inject_portion_train = ctx.attack.poison_ratio
    else:
        raise RuntimeError(
            'Now, we only support the FEMNIST and CIFAR-10 datasets')
    inject_portion_test = 1.0

    if mode == MODE.TRAIN:
        poisoned_dataset = _poison_with_pixel_trigger(data[mode].dataset, ctx,
                                                      mode,
                                                      inject_portion_train,
                                                      transforms_funcs)
        data[mode] = DataLoader(poisoned_dataset,
                                batch_size=ctx.dataloader.batch_size,
                                shuffle=True,
                                num_workers=ctx.dataloader.num_workers)

    if mode == MODE.TEST or mode == MODE.VAL:
        poisoned_dataset = _poison_with_pixel_trigger(data[mode].dataset, ctx,
                                                      mode,
                                                      inject_portion_test,
                                                      transforms_funcs)
        data['poison_' + mode] = _build_poison_loader(poisoned_dataset, ctx)

    return data


def add_trans_normalize(data, ctx):
    '''
    Normalize every sample of every loader in a client's data dictionary.

    data for each client is a dictionary.

    Each entry's ``dataset`` is rewritten in place with
    ``normalize(sample, ctx.attack.mean, ctx.attack.std)``. For CIFAR-10
    training data the sample is first converted to a PIL image and passed
    through a random horizontal flip followed by ``ToTensor``.

    Returns:
        The same ``data`` dictionary.
    '''
    mean, std = ctx.attack.mean, ctx.attack.std
    is_cifar10 = "CIFAR10" in ctx.data.type
    tran_train = None
    if is_cifar10:
        tran_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ])

    for key in data:
        dataset = data[key].dataset
        augment = is_cifar10 and key == MODE.TRAIN
        for iii in range(len(dataset)):
            sample, label = dataset[iii][0], dataset[iii][1]
            if augment:
                # Round-trip through PIL so the torchvision augmentation
                # pipeline can be applied.
                sample = tran_train(Image.fromarray(_to_uint8_hwc(sample)))
            dataset[iii] = (normalize(sample, mean, std), label)

    return data


def select_poisoning(data, ctx, mode):
    """Dispatch to the poisoning routine matching ``ctx.attack.trigger_type``.

    Trigger types containing ``'edge'`` use the edge-case loader, types
    containing ``'semantic'`` leave the data untouched, and everything else
    uses the pixel-trigger loader.
    """
    trigger_type = ctx.attack.trigger_type
    if 'edge' in trigger_type:
        return load_poisoned_dataset_edgeset(data, ctx, mode)
    if 'semantic' in trigger_type:
        return data
    return load_poisoned_dataset_pixel(data, ctx, mode)


def poisoning(data, ctx):
    """Prepare clean and poisoned data for every client, in place.

    Clients are expected under the keys ``1 .. len(data)``. Only the client
    whose id equals ``ctx.attack.attacker_id`` has its training data
    poisoned; poisoned test (and, when present, validation) loaders are
    built for every client, after which all loaders are normalized.
    """
    for i in range(1, len(data) + 1):
        if i == ctx.attack.attacker_id:
            logger.info(50 * '-')
            logger.info('start poisoning at Client: {}'.format(i))
            logger.info(50 * '-')
            data[i] = select_poisoning(data[i], ctx, mode=MODE.TRAIN)

        data[i] = select_poisoning(data[i], ctx, mode=MODE.TEST)
        if data[i].get(MODE.VAL):
            data[i] = select_poisoning(data[i], ctx, mode=MODE.VAL)
        data[i] = add_trans_normalize(data[i], ctx)

        # Adjacent literals instead of a backslash continuation, which put
        # stray characters/whitespace into the logged message.
        logger.info('finishing the clean and {} poisoning data processing '
                    'for Client {:d}'.format(ctx.attack.trigger_type, i))