import copy
import logging
import os
import pickle
import sys
# NOTE(review): NI_NAMEREQD is not used by the code visible here; kept in case
# another part of the project relies on it being importable from this module.
from socket import NI_NAMEREQD

import numpy as np
import torch
import torch.utils.data as data
from PIL import Image
from torchvision import transforms
from torchvision.datasets import MNIST, EMNIST, CIFAR10
from torchvision.datasets import DatasetFolder

logger = logging.getLogger(__name__)


def _load_ardis_base_images(data_path, images_file, labels_file, base_label):
    """Load one ARDIS split and return the images of class ``base_label``.

    The two CSV files are read with ``np.loadtxt``. Each image row is reshaped
    to 28x28, and a row is selected when column ``base_label`` of the label
    file equals 1 (i.e. the labels are treated as one-hot encoded).

    Returns a ``torch.uint8`` tensor of shape ``(num_selected, 28, 28)``.
    """
    # Paths are built by plain concatenation, so ``data_path`` has to end with
    # a path separator (or be a deliberate filename prefix).
    ardis_images = np.loadtxt(data_path + images_file, dtype='float')
    ardis_labels = np.loadtxt(data_path + labels_file, dtype='float')

    # reshape to be [samples][28][28]
    ardis_images = ardis_images.reshape(
        ardis_images.shape[0], 28, 28).astype('float32')

    selected = np.where(ardis_labels[:, base_label] == 1)[0]
    return torch.tensor(ardis_images[selected, :]).type(torch.uint8)


def _constant_labels(num_samples, target_label):
    """Return a ``long`` tensor of length ``num_samples`` filled with ``target_label``."""
    return (torch.zeros(num_samples) + target_label).long()


def create_ardis_poisoned_dataset(data_path, base_label=7, target_label=1, fraction=0.1):
    """Build a dirty-label poisoned training set from the ARDIS train split.

    Every ARDIS training image whose one-hot label marks ``base_label`` is
    relabelled as ``target_label``.

    Args:
        data_path: Directory prefix holding ``ARDIS_train_2828.csv`` and
            ``ARDIS_train_labels.csv``; it is concatenated with the file name
            as-is, so it must end with a path separator.
        base_label: Column of the one-hot label file that selects the images.
        target_label: Label assigned to every returned image.
        fraction: If ``< 1``, a random subset of that fraction of the selected
            images is returned. If ``>= 1``, all selected images are returned
            together with ``int(fraction)`` rotated copies of each one
            (angles ``180 / fraction * i`` for ``i = 1..int(fraction)``).

    Returns:
        A list of ``(image, label)`` tuples with ``uint8`` image tensors and
        0-dim ``long`` label tensors.

        NOTE(review): in the sub-sampling branch each image has shape
        ``(1, 28, 28)``, while in the augmentation branch it has shape
        ``(28, 28)`` (no channel dimension). This mirrors the original
        behaviour -- confirm that downstream consumers expect it.
    """
    images_seven = _load_ardis_base_images(
        data_path, 'ARDIS_train_2828.csv', 'ARDIS_train_labels.csv', base_label)
    num_available = images_seven.size()[0]

    if fraction < 1:
        # Random sub-sample of the selected images.
        num_sampled_data_points = int(fraction * num_available)
        perm = torch.randperm(num_available)
        poisoned_images = images_seven[perm[:num_sampled_data_points]].unsqueeze(1)
        # Use a %s placeholder: passing the size as a bare extra argument makes
        # the logging module fail to format the record.
        logger.info('size of images_seven_cut: %s', poisoned_images.size())
    else:
        # int() lets float values such as 1.0 or 2.0 work; range() rejects floats.
        num_rotations = int(fraction)
        cand_angles = [180 / fraction * i for i in range(1, num_rotations + 1)]
        logger.info("Candidate angles for DA: %s", cand_angles)

        # Data augmentation: rotated copies are collected first and joined
        # once, instead of re-allocating the whole tensor for every rotation.
        to_pil = transforms.ToPILImage()
        rotated_images = []
        for image in images_seven:
            pil_img = to_pil(image).convert("L")
            for angle in cand_angles:
                pil_img_rotate = transforms.functional.rotate(
                    pil_img, angle, fill=(0, ))
                rotated_images.append(
                    torch.from_numpy(np.array(pil_img_rotate)))

        if rotated_images:
            # Originals first, then rotations grouped per source image.
            poisoned_images = torch.cat(
                (images_seven, torch.stack(rotated_images)), 0)
        else:
            # No image of class ``base_label`` was found.
            poisoned_images = copy.deepcopy(images_seven)
        logger.info("%s", poisoned_images.size())

    poisoned_labels = _constant_labels(poisoned_images.size()[0], target_label)
    return list(zip(poisoned_images, poisoned_labels))


def create_ardis_test_dataset(data_path, base_label=7, target_label=1):
    """Build the poisoned evaluation set from the ARDIS test split.

    Every ARDIS test image whose one-hot label marks ``base_label`` is paired
    with ``target_label``.

    Args:
        data_path: Directory prefix holding ``ARDIS_test_2828.csv`` and
            ``ARDIS_test_labels.csv``; it is concatenated with the file name
            as-is, so it must end with a path separator.
        base_label: Column of the one-hot label file that selects the images.
        target_label: Label assigned to every returned image.

    Returns:
        A list of ``(image, label)`` tuples, where ``image`` is a ``uint8``
        tensor of shape ``(1, 28, 28)`` and ``label`` is a 0-dim ``long``
        tensor.
    """
    images_seven = _load_ardis_base_images(
        data_path, 'ARDIS_test_2828.csv', 'ARDIS_test_labels.csv', base_label)
    # Add the channel dimension: (N, 28, 28) -> (N, 1, 28, 28).
    images_seven = images_seven.unsqueeze(1)

    poisoned_labels = _constant_labels(images_seven.size()[0], target_label)
    return list(zip(images_seven, poisoned_labels))