149 lines
4.6 KiB
Python
149 lines
4.6 KiB
Python
import math
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.backends.cudnn as cudnn
|
|
import torchvision
|
|
import torchvision.transforms as T
|
|
import torchvision.transforms.functional as TF
|
|
from torch.utils.data import DataLoader, Dataset
|
|
from torchvision.datasets import CIFAR10, CIFAR100
|
|
import pickle as pkl
|
|
import numpy as np
|
|
from federatedscope.register import register_data
|
|
from federatedscope.core.auxiliaries.splitter_builder import get_splitter
|
|
|
|
|
|
class SimCLRTransform:
    r"""SimCLR-style image augmentation pipeline.

    Adapted from
    https://github.com/akhilmathurs/orchestra/blob/main/utils.py

    The pipeline is: random resized crop (bicubic), horizontal flip,
    colour jitter (p=0.8), grayscale (p=0.2), Gaussian blur (p=0.5),
    then ``ToTensor`` and per-channel normalisation with mean/std 0.5.

    Arguments:
        is_sup (bool): when truthy, calling the transform yields a single
            augmented view (supervised learning); otherwise it yields two
            independently augmented views (contrastive learning).
        image_size (int): output side length passed to
            ``RandomResizedCrop``.

    Calling an instance returns:
        torch.Tensor: one augmented view, when ``is_sup`` is truthy.
        tuple(torch.Tensor, torch.Tensor): two augmented views otherwise.
    """

    def __init__(self, is_sup, image_size=32):
        # Random, image-space augmentations applied before tensor conversion.
        augmentations = [
            T.RandomResizedCrop(image_size,
                                scale=(0.5, 1.0),
                                interpolation=T.InterpolationMode.BICUBIC),
            T.RandomHorizontalFlip(p=0.5),
            T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8),
            T.RandomGrayscale(p=0.2),
            T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))],
                          p=0.5),
        ]
        # Deterministic conversion to a normalised tensor.
        to_normalized_tensor = [
            T.ToTensor(),
            T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
        self.transform = T.Compose(augmentations + to_normalized_tensor)
        # Truthy -> supervised (single view); falsy -> contrastive (two views).
        self.mode = is_sup

    def __call__(self, x):
        if not self.mode:
            # Contrastive mode: two independent random draws of the pipeline,
            # evaluated left to right.
            return self.transform(x), self.transform(x)
        return self.transform(x)
|
|
|
|
|
|
def Cifar4CL(config):
    r"""Build CIFAR-10 train/val/test datasets for contrastive learning.

    Every split uses ``SimCLRTransform(is_sup=False, image_size=32)``, so
    the image of each sample is transformed into a pair of independently
    augmented views.

    Arguments:
        config: configuration object; only ``config.data.root`` (the
            directory CIFAR-10 is read from / downloaded into) is used.

    Returns:
        tuple: ``((data_train, data_val, data_test), config)`` where each
        dataset is a ``torchvision.datasets.CIFAR10`` instance and
        ``config`` is returned unchanged. ``data_val`` is the very same
        object as ``data_train`` (no separate validation split is made).
    """
    transform_train = SimCLRTransform(is_sup=False, image_size=32)

    path = config.data.root

    data_train = CIFAR10(path,
                         train=True,
                         download=True,
                         transform=transform_train)
    # The test split also goes through the two-view contrastive transform
    # (not a deterministic test-time transform).
    data_test = CIFAR10(path,
                        train=False,
                        download=True,
                        transform=transform_train)

    # No held-out validation split: validation reuses the training set.
    data_val = data_train

    return (data_train, data_val, data_test), config
|
|
|
|
|
|
def Cifar4LP(config):
    r"""Build CIFAR-10 train/val/test datasets for linear-probe evaluation
    of a contrastively trained model.

    The training split uses light augmentation (random resized crop with
    bicubic interpolation and a horizontal flip). The validation and test
    splits use only ``ToTensor`` followed by normalisation with mean/std
    0.5 per channel.

    Arguments:
        config: configuration object; only ``config.data.root`` (the
            directory CIFAR-10 is read from / downloaded into) is used.

    Returns:
        tuple: ``((data_train, data_val, data_test), config)`` where each
        dataset is a ``torchvision.datasets.CIFAR10`` instance and
        ``config`` is returned unchanged. ``data_val`` contains the CIFAR-10
        training images passed through the deterministic test transform.
    """
    transform_train = T.Compose([
        T.RandomResizedCrop(32,
                            scale=(0.5, 1.0),
                            interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    transform_test = T.Compose(
        [T.ToTensor(),
         T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    path = config.data.root

    data_train = CIFAR10(path,
                         train=True,
                         download=True,
                         transform=transform_train)
    # Same images as the training split, but built with ``transform_test``
    # so validation metrics are not perturbed by random train-time
    # augmentation. Do not alias this to ``data_train``: that would discard
    # this dataset and evaluate on augmented images.
    data_val = CIFAR10(path,
                       train=True,
                       download=True,
                       transform=transform_test)
    data_test = CIFAR10(path,
                        train=False,
                        download=True,
                        transform=transform_test)

    return (data_train, data_val, data_test), config
|
|
|
|
|
|
def load_cifar_dataset(config):
    """Dispatch to the CIFAR-10 builder named by ``config.data.type``.

    Arguments:
        config: configuration object; ``config.data.type`` selects the
            builder (``"Cifar4CL"`` or ``"Cifar4LP"``) and the whole
            object is forwarded to it.

    Returns:
        ``(data, modified_config)`` from the selected builder, or ``None``
        when ``config.data.type`` is not one of the two names above.
    """
    data_type = config.data.type
    if data_type == "Cifar4CL":
        builder = Cifar4CL
    elif data_type == "Cifar4LP":
        builder = Cifar4LP
    else:
        # NOTE(review): presumably returning None lets other loaders
        # registered via ``register_data`` claim this type — confirm
        # against federatedscope.register.
        return None

    data, modified_config = builder(config)
    return data, modified_config
|