FS-TFP/federatedscope/cl/dataloader/Cifar10.py

149 lines
4.6 KiB
Python

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as T
import torchvision.transforms.functional as TF
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import CIFAR10, CIFAR100
import pickle as pkl
import numpy as np
from federatedscope.register import register_data
from federatedscope.core.auxiliaries.splitter_builder import get_splitter
class SimCLRTransform():
    r"""
    SimCLR-style data augmentation pipeline, adapted from
    https://github.com/akhilmathurs/orchestra/blob/main/utils.py

    Arguments:
        is_sup (bool): if true, calling the object returns a single
            augmented view (supervised learning); otherwise it returns two
            views obtained by running the pipeline twice on the same input
            (contrastive learning).
        image_size (int): output size handed to ``RandomResizedCrop``.
    :returns:
        torch.tensor: one augmented view when ``is_sup`` is true.
    :returns:
        (torch.tensor, torch.tensor): two augmented views when ``is_sup``
        is false.
    """
    def __init__(self, is_sup, image_size=32):
        # Each stage is named so the pipeline below reads top to bottom.
        random_crop = T.RandomResizedCrop(
            image_size,
            scale=(0.5, 1.0),
            interpolation=T.InterpolationMode.BICUBIC)
        random_flip = T.RandomHorizontalFlip(p=0.5)
        random_jitter = T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)],
                                      p=0.8)
        random_gray = T.RandomGrayscale(p=0.2)
        random_blur = T.RandomApply(
            [T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5)
        to_tensor = T.ToTensor()
        normalize = T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        self.transform = T.Compose([
            random_crop,
            random_flip,
            random_jitter,
            random_gray,
            random_blur,
            to_tensor,
            normalize,
        ])
        # Truthy -> single view, falsy -> pair of views (see __call__).
        self.mode = is_sup

    def __call__(self, x):
        if not self.mode:
            # Contrastive mode: run the pipeline twice on the same input.
            first_view = self.transform(x)
            second_view = self.transform(x)
            return first_view, second_view
        return self.transform(x)
def Cifar4CL(config):
    r"""
    Build the CIFAR-10 datasets used for contrastive learning.

    Both the train and the test dataset are created with
    ``SimCLRTransform(is_sup=False, image_size=32)``, i.e. the transform
    that returns two augmented views per image. No separate validation
    data is built: the training dataset object is returned for the
    validation slot as well.

    Arguments:
        config: configuration object; only ``config.data.root`` is read
            here and is passed to ``CIFAR10`` as its root directory
            (``download=True`` is set for both datasets).
    :returns:
        tuple: ``(data_train, data_val, data_test)`` where ``data_val`` is
        the same object as ``data_train``.
        config: the ``config`` argument, returned unchanged.
    """
    two_view_transform = SimCLRTransform(is_sup=False, image_size=32)
    root = config.data.root

    data_train = CIFAR10(root,
                         train=True,
                         download=True,
                         transform=two_view_transform)
    data_test = CIFAR10(root,
                        train=False,
                        download=True,
                        transform=two_view_transform)

    # The training dataset doubles as the validation dataset.
    data_val = data_train
    return (data_train, data_val, data_test), config
def Cifar4LP(config):
    r"""
    Build the CIFAR-10 datasets used for linear-probe evaluation of
    contrastive learning.

    The training dataset uses a light augmentation (random resized crop
    and horizontal flip); the validation and test datasets only convert
    to tensor and normalize. The validation dataset is the CIFAR-10
    training images (``train=True``) with the evaluation transform, the
    test dataset is the CIFAR-10 test images (``train=False``).

    Arguments:
        config: configuration object; only ``config.data.root`` is read
            here and is passed to ``CIFAR10`` as its root directory
            (``download=True`` is set for every dataset).
    :returns:
        tuple: ``(data_train, data_val, data_test)`` CIFAR10 datasets.
        config: the ``config`` argument, returned unchanged.
    """
    transform_train = T.Compose([
        T.RandomResizedCrop(32,
                            scale=(0.5, 1.0),
                            interpolation=T.InterpolationMode.BICUBIC),
        T.RandomHorizontalFlip(p=0.5),
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    transform_test = T.Compose([
        T.ToTensor(),
        T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    path = config.data.root
    data_train = CIFAR10(path,
                         train=True,
                         download=True,
                         transform=transform_train)
    # Keep this dataset as the validation split. Previously it was built
    # and then immediately replaced by ``data_train``, so validation ran
    # on randomly augmented images and this load was wasted.
    data_val = CIFAR10(path,
                       train=True,
                       download=True,
                       transform=transform_test)
    data_test = CIFAR10(path,
                        train=False,
                        download=True,
                        transform=transform_test)

    return (data_train, data_val, data_test), config
def load_cifar_dataset(config):
    r"""
    Dispatch on ``config.data.type`` to the matching CIFAR-10 builder.

    Arguments:
        config: configuration object exposing ``config.data.type``; it is
            forwarded unchanged to the selected builder.
    :returns:
        tuple: ``(data, modified_config)`` as produced by ``Cifar4CL`` when
        the type equals ``"Cifar4CL"`` or by ``Cifar4LP`` when it equals
        ``"Cifar4LP"`` (exact, case-sensitive comparison).
        None: for any other data type.
    """
    requested = config.data.type
    if requested == "Cifar4CL":
        builder = Cifar4CL
    elif requested == "Cifar4LP":
        builder = Cifar4LP
    else:
        # Not a data type handled by this loader.
        return None

    data, modified_config = builder(config)
    return data, modified_config