367 lines
11 KiB
Python
367 lines
11 KiB
Python
import torch.utils.data as data
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torchvision
|
|
import torchvision.transforms as transforms
|
|
import os
|
|
import csv
|
|
import random
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
import time
|
|
# import cv2
|
|
import matplotlib
|
|
from matplotlib import image as mlt
|
|
|
|
|
|
def normalize(X, mean, std, device=None):
    """Normalize a channel-first tensor per channel: ``(X - mean) / std``.

    Args:
        X: tensor whose first dimension is the channel axis; ``mean`` and
            ``std`` are reshaped to ``(C, 1, 1)``, so a 3-D ``(C, H, W)``
            layout is expected.
        mean: per-channel means (sequence, array or tensor of length C).
        std: per-channel standard deviations (length C).
        device: unused; kept for backward compatibility. The statistics are
            always placed on ``X.device``.

    Returns:
        A new tensor ``(X - mean) / std``.
    """
    channel = X.shape[0]
    # as_tensor avoids the copy-construct warning when mean/std are already
    # tensors; device=X.device keeps the arithmetic on one device (a CPU
    # (C, 1, 1) tensor cannot be combined with a CUDA X).
    mean = torch.as_tensor(mean, device=X.device).reshape(channel, 1, 1)
    std = torch.as_tensor(std, device=X.device).reshape(channel, 1, 1)
    return (X - mean) / std
|
|
|
|
|
|
def selectTrigger(img, height, width, distance, trig_h, trig_w, triggerType,
                  load_path):
    '''
    Apply the backdoor trigger named ``triggerType`` to ``img``.

    Args:
        img: image array; the trigger helpers index it as
            (height, width[, channel]).
        height, width: spatial size of ``img``.
        distance, trig_h, trig_w: placement/size parameters forwarded to the
            trigger helper (several helpers ignore them).
        triggerType: name of the trigger; see the dispatch tables below.
        load_path: directory holding pattern files; only 'signalTrigger' and
            'hkTrigger' read from it.

    Returns:
        the img: np.array [0:255], (height, width, channel)

    Raises:
        NotImplementedError: if ``triggerType`` is not a supported trigger.
    '''
    # Helpers taking (img, height, width, distance, trig_h, trig_w). The table
    # is built at call time because the helpers are defined later in the module.
    simple_triggers = {
        'squareTrigger': _squareTrigger,
        'gridTrigger': _gridTriger,
        'fourCornerTrigger': _fourCornerTrigger,
        'fourCorner_w_Trigger': _fourCorner_w_Trigger,
        'randomPixelTrigger': _randomPixelTrigger,
        'sigTrigger': _sigTrigger,
        'sig_n_Trigger': _sig_n_Trigger,
        'wanetTrigger': _wanetTrigger,
        'wanetTriggerCross': _wanetTriggerCross,
    }
    # Helpers that additionally read a pattern file from ``load_path``. The
    # previous dispatch omitted this argument, so these always raised TypeError.
    file_triggers = {
        'signalTrigger': _signalTrigger,
        'hkTrigger': _hkTrigger,
    }

    if triggerType in file_triggers:
        return file_triggers[triggerType](img, height, width, distance,
                                          trig_h, trig_w, load_path)
    if triggerType in simple_triggers:
        return simple_triggers[triggerType](img, height, width, distance,
                                            trig_h, trig_w)
    # Explicit error rather than ``assert``, which ``python -O`` strips.
    raise NotImplementedError(
        'Unsupported triggerType %r; expected one of %s'
        % (triggerType, sorted(list(simple_triggers) + list(file_triggers))))
|
|
|
|
|
|
def _squareTrigger(img, height, width, distance, trig_h, trig_w):
    """Paint a solid white ``trig_h`` x ``trig_w`` square near the bottom-right.

    The square's bottom and right edges sit ``distance`` pixels in from the
    image's bottom and right borders.

    Args:
        img: NumPy array indexed as ``img[row, col, ...]`` (e.g.
            (height, width, channel)); modified in place.
        height, width: spatial size of ``img``.
        distance: margin between the square and the bottom/right borders.
        trig_h, trig_w: square height (rows) and width (columns).

    Returns:
        The same ``img`` object.
    """
    # Clamp at 0 so an oversized trigger is cut off at the top/left border
    # instead of wrapping around via negative indices.
    top = max(height - distance - trig_h, 0)
    left = max(width - distance - trig_w, 0)
    # Rows come from the height axis and columns from the width axis. The old
    # loop indexed img[width_index, height_index], which only worked for
    # square images.
    img[top:height - distance, left:width - distance] = 255
    return img
|
|
|
|
|
|
def _gridTriger(img, height, width, distance, trig_h, trig_w):
    """Stamp a fixed 3x3 black/white grid into the bottom-right corner.

    Only ``height`` and ``width`` are used; ``distance``, ``trig_h`` and
    ``trig_w`` exist so every trigger helper shares one signature. ``img`` is
    modified in place and returned.
    """
    # grid[a][b] is written to img[height - 1 - a][width - 1 - b], iterating
    # row by row (a outer, b inner).
    grid = ((255, 0, 255),
            (0, 255, 0),
            (255, 0, 0))
    for row_off, row_vals in enumerate(grid, start=1):
        for col_off, pixel in enumerate(row_vals, start=1):
            img[height - row_off][width - col_off] = pixel
    return img
|
|
|
|
|
|
def _fourCornerTrigger(img, height, width, distance, trig_h, trig_w):
    """Stamp a fixed 3x3 black/white grid into each of the four corners.

    The bottom/right patches touch the image border (rows/columns
    ``-1..-3``). The top/left patches start one pixel in (indices 1..3).
    Only ``height`` and ``width`` are used; ``img`` is modified in place and
    returned.
    """
    grid = ((255, 0, 255),
            (0, 255, 0),
            (255, 0, 0))
    far_rows = (height - 1, height - 2, height - 3)
    far_cols = (width - 1, width - 2, width - 3)
    near = (1, 2, 3)

    def stamp(rows, cols):
        # grid[a][b] -> img[rows[a]][cols[b]], row by row.
        for r, vals in zip(rows, grid):
            for c, pixel in zip(cols, vals):
                img[r][c] = pixel

    stamp(far_rows, far_cols)  # bottom-right
    stamp(near, near)  # top-left
    stamp(far_rows, near)  # bottom-left (high row index, low column index)
    # top-right, written column by column: grid[a][b] -> img[near[b]][far_cols[a]]
    for c, vals in zip(far_cols, grid):
        for r, pixel in zip(near, vals):
            img[r][c] = pixel
    return img
|
|
|
|
|
|
def _fourCorner_w_Trigger(img, height, width, distance, trig_h, trig_w):
    """Paint a solid white 3x3 patch into each of the four corners.

    The bottom/right patches touch the image border (rows/columns
    ``-1..-3``). The top/left patches start one pixel in (indices 1..3).
    Only ``height`` and ``width`` are used; ``img`` is modified in place and
    returned.
    """
    near = (1, 2, 3)
    far_rows = (height - 1, height - 2, height - 3)
    far_cols = (width - 1, width - 2, width - 3)
    # (rows, cols): bottom-right, top-left, bottom-left, top-right.
    # The top-right patch previously used ``height - 3`` as a column index,
    # which misplaced one column on non-square images.
    corners = ((far_rows, far_cols), (near, near), (far_rows, near),
               (near, far_cols))
    for rows, cols in corners:
        for r in rows:
            for c in cols:
                img[r][c] = 255
    return img
|
|
|
|
|
|
def _randomPixelTrigger(img, height, width, distance, trig_h, trig_w):
    """Blend ``img`` with uniform random noise: ``0.8 * img + 0.2 * noise``.

    A fresh ``(height, width)`` noise mask with values in [0, 255] is drawn
    from ``np.random`` on each call. The same noise value is applied to every
    channel of a pixel. ``distance``, ``trig_h`` and ``trig_w`` are unused.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    alpha = 0.2
    mask = np.random.randint(low=0,
                             high=256,
                             size=(height, width),
                             dtype=np.uint8)
    # Append singleton axes so the mask broadcasts over any trailing
    # (channel) dims. For a 2-D grayscale image the old fixed (H, W, 1)
    # reshape broadcast to the wrong shape.
    mask = mask.reshape((height, width) + (1,) * (np.ndim(img) - 2))
    blend_img = (1 - alpha) * img + alpha * mask
    # Clip while still floating point: casting out-of-range floats to uint8
    # first (as before) is undefined and typically wraps around.
    return np.clip(blend_img, 0, 255).astype('uint8')
|
|
|
|
|
|
def _signalTrigger(img, height, width, distance, trig_h, trig_w, load_path):
    """Blend ``img`` with a stored signal pattern: ``0.8 * img + 0.2 * mask``.

    The mask is a vertical stripe pattern, different from the sig trigger. It
    is loaded from ``<load_path>/signal_cifar10_mask.npy`` and must contain
    ``height * width`` values; it is applied equally to every channel.
    ``distance``, ``trig_h`` and ``trig_w`` are unused.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    alpha = 0.2
    mask_path = os.path.join(load_path, 'signal_cifar10_mask.npy')
    signal_mask = np.load(mask_path)
    # (H, W) plus singleton axes for any trailing channel dims of img.
    signal_mask = signal_mask.reshape(
        (height, width) + (1,) * (np.ndim(img) - 2))
    blend_img = (1 - alpha) * img + alpha * signal_mask
    # Clip before the uint8 cast; casting out-of-range floats wraps.
    return np.clip(blend_img, 0, 255).astype('uint8')
|
|
|
|
|
|
def _hkTrigger(img, height, width, distance, trig_h, trig_w, load_path):
    """Blend ``img`` with a Hello Kitty image: ``0.8 * img + 0.2 * pattern``.

    The pattern is read from ``<load_path>/hello_kitty.png`` and scaled by
    255. It is not resized, so it must already match ``img``'s spatial size.
    ``height``, ``width``, ``distance``, ``trig_h`` and ``trig_w`` are unused.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    alpha = 0.2
    pattern_path = os.path.join(load_path, 'hello_kitty.png')
    # The * 255 assumes imread yields floats in [0, 1] (matplotlib's PNG
    # behaviour).
    signal_mask = mlt.imread(pattern_path) * 255
    # An RGBA PNG cannot broadcast against an RGB image; drop the alpha plane.
    if (signal_mask.ndim == 3 and signal_mask.shape[-1] == 4
            and np.ndim(img) == 3 and np.shape(img)[-1] == 3):
        signal_mask = signal_mask[..., :3]
    blend_img = (1 - alpha) * img + alpha * signal_mask
    # Clip before the uint8 cast; casting out-of-range floats wraps.
    return np.clip(blend_img, 0, 255).astype('uint8')
|
|
|
|
|
|
def _sigTrigger(img, height, width, distance, trig_h, trig_w, delta=20, f=6):
    """
    Implement paper:
    > Barni, M., Kallas, K., & Tondi, B. (2019).
    > arXiv preprint arXiv:1902.11237
    superimposed sinusoidal backdoor signal with default parameters

    Adds ``delta * sin(2 * pi * j * f / W)`` to every pixel in column ``j``
    (``W`` = ``img.shape[1]``), for all rows and channels, then clips to
    [0, 255].

    Args:
        img: image array indexed (row, column[, channel]).
        delta: signal amplitude. This value used to be overwritten with 20,
            so the argument had no effect; the default is unchanged.
        f: number of sine periods across the image width.
        height, width, distance, trig_h, trig_w: unused.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    img = np.float32(img)
    m = img.shape[1]
    # The signal depends only on the column index: compute one value per
    # column (float32, like the old per-pixel pattern buffer) and broadcast.
    signal = (delta * np.sin(2 * np.pi * np.arange(m) * f / m)).astype(np.float32)
    signal = signal.reshape((1, m) + (1,) * (img.ndim - 2))
    # uint32 conversion truncates any fractional part of the input.
    img = np.uint32(img) + signal
    return np.uint8(np.clip(img, 0, 255))
|
|
|
|
|
|
def _sig_n_Trigger(img,
                   height,
                   width,
                   distance,
                   trig_h,
                   trig_w,
                   delta=40,
                   f=6):
    """
    Implement paper:
    > Barni, M., Kallas, K., & Tondi, B. (2019).
    > arXiv preprint arXiv:1902.11237
    superimposed sinusoidal backdoor signal with default parameters

    Adds ``10 * sin(2 * pi * j * f / W)`` to every pixel in column ``j``
    (``W`` = ``img.shape[1]``), for all rows and channels, then clips to
    [0, 255].

    Args:
        img: image array indexed (row, column[, channel]).
        delta: ignored (see the note below).
        f: number of sine periods across the image width.
        height, width, distance, trig_h, trig_w: unused.

    Returns:
        A new uint8 array with the same shape as ``img``.
    """
    # NOTE(review): the amplitude is hard-wired to 10, so ``delta`` (default
    # 40) has no effect. This is kept as-is because honoring ``delta`` would
    # change the default output. Confirm which amplitude is intended.
    amplitude = 10
    img = np.float32(img)
    m = img.shape[1]
    # One value per column, broadcast over rows and channels. This replaces
    # a Python double loop over every pixel.
    signal = (amplitude * np.sin(2 * np.pi * np.arange(m) * f / m)).astype(np.float32)
    signal = signal.reshape((1, m) + (1,) * (img.ndim - 2))
    img = np.uint32(img) + signal
    return np.uint8(np.clip(img, 0, 255))
|
|
|
|
|
|
def _wanetTrigger(img, height, width, distance, trig_w, trig_h, delta=20, f=6):
    """
    Implement paper:
    > WaNet -- Imperceptible Warping-based Backdoor Attack
    > Anh Nguyen, Anh Tran, ICLR 2021
    > https://arxiv.org/abs/2102.10369

    Warp ``img`` with a smooth random field. The field is a random 4x4
    control grid, upsampled bicubically to the image size, added to the
    identity sampling grid, and then sampled bilinearly with ``grid_sample``.

    Args:
        img: array of shape (height, width, channel) or (height, width),
            values in [0, 255].
        height, width: spatial size of ``img``.
        distance, trig_w, trig_h, delta, f: unused.

    Returns:
        uint8 array with the same shape as ``img``.

    NOTE(review): a fresh random warp field is drawn from ``torch.rand`` on
    every call, so each image gets a different warp. The WaNet paper uses a
    single fixed field; confirm that per-call randomness is intended.
    """
    k = 4  # control-grid resolution of the random field
    s = 0.5  # warp strength
    grid_rescale = 1
    ins = torch.rand(1, 2, k, k) * 2 - 1
    ins = ins / torch.mean(torch.abs(ins))
    # F.interpolate replaces the deprecated F.upsample (same computation).
    noise_grid = F.interpolate(ins,
                               size=(height, width),
                               mode="bicubic",
                               align_corners=True).permute(0, 2, 3, 1)
    rows = torch.linspace(-1, 1, steps=height)
    cols = torch.linspace(-1, 1, steps=width)
    # grid_sample reads (x, y) = (column, row) from the last grid dimension.
    identity_grid = torch.stack(
        (cols.view(1, width).expand(height, width),
         rows.view(height, 1).expand(height, width)), 2)[None, ...]
    grid_temps = (identity_grid + s * noise_grid / height) * grid_rescale
    grid_temps = torch.clamp(grid_temps, -1, 1)

    arr = np.float32(img)
    is_gray = arr.ndim == 2
    if is_gray:
        arr = arr[:, :, None]
    # HWC -> 1xCxHxW via a transpose. The previous reshape reinterpreted the
    # memory and scrambled pixels across channels.
    tensor = torch.tensor(arr).permute(2, 0, 1).contiguous().unsqueeze(0)
    warped = F.grid_sample(tensor, grid_temps, align_corners=True)
    out = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
    if is_gray:
        out = out[:, :, 0]
    return np.uint8(np.clip(out, 0, 255))
|
|
|
|
|
|
def _wanetTriggerCross(img, height, width, distance, trig_w, trig_h):
    """
    Implement paper:
    > WaNet -- Imperceptible Warping-based Backdoor Attack
    > Anh Nguyen, Anh Tran, ICLR 2021
    > https://arxiv.org/abs/2102.10369

    Warp ``img`` like ``_wanetTrigger`` (random smooth field added to the
    identity grid). Additionally, uniform per-pixel jitter in
    ``[-1/height, 1/height)`` is added to the sampling grid before sampling.

    Args:
        img: array of shape (height, width, channel) or (height, width),
            values in [0, 255].
        height, width: spatial size of ``img``.
        distance, trig_w, trig_h: unused.

    Returns:
        uint8 array with the same shape as ``img``.
    """
    k = 4  # control-grid resolution of the random field
    s = 0.5  # warp strength
    grid_rescale = 1
    ins = torch.rand(1, 2, k, k) * 2 - 1
    ins = ins / torch.mean(torch.abs(ins))
    # F.interpolate replaces the deprecated F.upsample (same computation).
    noise_grid = F.interpolate(ins,
                               size=(height, width),
                               mode="bicubic",
                               align_corners=True).permute(0, 2, 3, 1)
    rows = torch.linspace(-1, 1, steps=height)
    cols = torch.linspace(-1, 1, steps=width)
    # grid_sample reads (x, y) = (column, row) from the last grid dimension.
    identity_grid = torch.stack(
        (cols.view(1, width).expand(height, width),
         rows.view(height, 1).expand(height, width)), 2)[None, ...]
    grid_temps = (identity_grid + s * noise_grid / height) * grid_rescale
    grid_temps = torch.clamp(grid_temps, -1, 1)
    # Per-pixel random jitter on top of the smooth warp.
    jitter = torch.rand(1, height, width, 2) * 2 - 1
    grid_temps2 = torch.clamp(grid_temps + jitter / height, -1, 1)

    arr = np.float32(img)
    is_gray = arr.ndim == 2
    if is_gray:
        arr = arr[:, :, None]
    # HWC -> 1xCxHxW via a transpose. The previous reshape reinterpreted the
    # memory and scrambled pixels across channels.
    tensor = torch.tensor(arr).permute(2, 0, 1).contiguous().unsqueeze(0)
    warped = F.grid_sample(tensor, grid_temps2, align_corners=True)
    out = warped.squeeze(0).permute(1, 2, 0).cpu().numpy()
    if is_gray:
        out = out[:, :, 0]
    return np.uint8(np.clip(out, 0, 255))
|