import math import torch import torch.nn as nn import torch.nn.functional as F import torch.backends.cudnn as cudnn import torchvision import torchvision.transforms as T import torchvision.transforms.functional as TF from torch.utils.data import DataLoader, Dataset from torchvision.datasets import CIFAR10, CIFAR100 import pickle as pkl import numpy as np from federatedscope.register import register_data from federatedscope.core.auxiliaries.splitter_builder import get_splitter class SimCLRTransform(): r""" Data Augmentations of SimCLR refer from https://github.com/akhilmathurs/orchestra/blob/main/utils.py Arguments: is_sup (bool): the transform for supervised learning or contrastive learning. :returns: torch.tensor: one output for supervised learning. :returns: torch.tensor: two output for contrastive learning torch.tensor: two output for contrastive learning """ def __init__(self, is_sup, image_size=32): self.transform = T.Compose([ T.RandomResizedCrop(image_size, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), T.RandomApply([T.ColorJitter(0.4, 0.4, 0.2, 0.1)], p=0.8), T.RandomGrayscale(p=0.2), T.RandomApply([T.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.5), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) self.mode = is_sup def __call__(self, x): if (self.mode): return self.transform(x) else: x1 = self.transform(x) x2 = self.transform(x) return x1, x2 def Cifar4CL(config): r""" generate Cifar10 Dataset transform and split dict for contrastive learning return { 'client_id': { 'train': DataLoader(), 'test': DataLoader(), 'val': DataLoader() } } """ transform_train = SimCLRTransform(is_sup=False, image_size=32) path = config.data.root data_train = CIFAR10(path, train=True, download=True, transform=transform_train) data_test = CIFAR10(path, train=False, download=True, transform=transform_train) # Split data into dict data_dict = dict() data_val = data_train data_dict = {'train': data_train, 'val': data_val, 'test': data_test} data_split_tuple = (data_dict.get('train'), data_dict.get('val'), data_dict.get('test')) config = config return data_split_tuple, config def Cifar4LP(config): r""" generate Cifar10 Dataset transform and split dict for linear prob evaluation of contrastive learning return { 'client_id': { 'train': DataLoader(), 'test': DataLoader(), 'val': DataLoader() } } """ transform_train = T.Compose([ T.RandomResizedCrop(32, scale=(0.5, 1.0), interpolation=T.InterpolationMode.BICUBIC), T.RandomHorizontalFlip(p=0.5), T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) ]) transform_test = T.Compose( [T.ToTensor(), T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) path = config.data.root data_train = CIFAR10(path, train=True, download=True, transform=transform_train) data_val = CIFAR10(path, train=True, download=True, transform=transform_test) data_test = CIFAR10(path, train=False, download=True, transform=transform_test) # Split data into dict data_dict = dict() data_val = data_train data_dict = {'train': data_train, 'val': data_val, 'test': data_test} data_split_tuple = (data_dict.get('train'), data_dict.get('val'), data_dict.get('test')) config = config return data_split_tuple, config def load_cifar_dataset(config): if config.data.type == "Cifar4CL": data, modified_config = Cifar4CL(config) return data, modified_config elif config.data.type == "Cifar4LP": data, modified_config = Cifar4LP(config) return data, modified_config