FS-TFP/federatedscope/cv/dataset/leaf_cv.py

179 lines
6.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import random
import json
import torch
import math
import numpy as np
import os.path as osp
from PIL import Image
from tqdm import tqdm
from sklearn.model_selection import train_test_split
from federatedscope.core.data.utils import save_local_data, download_url
from federatedscope.cv.dataset.leaf import LEAF
# Target array shape per dataset name; used as the `new_shape` argument of
# `np.resize` when a stored sample is turned back into an image.
IMAGE_SIZE = {
    'femnist': (28, 28),
    'celeba': (84, 84, 3),
}

# PIL image mode per dataset name; passed to `Image.fromarray`.
MODE = {
    'femnist': 'L',
    'celeba': 'RGB',
}
class LEAF_CV(LEAF):
    """
    LEAF CV dataset from "LEAF: A Benchmark for Federated Settings"

    leaf.cmu.edu

    On construction, every ``task_<idx>`` directory found in
    ``self.processed_dir`` is loaded into ``self.data_dict`` keyed by the
    integer ``<idx>``. Each entry maps ``'train'`` / ``'test'`` (and
    ``'val'`` when ``val.pt`` exists) to a ``(data, targets)`` pair.

    Arguments:
        root (str): root path.
        name (str): name of dataset, femnist or celeba.
        s_frac (float): fraction of the raw JSON files that ``process``
            keeps; default=0.3.
        tr_frac (float): train set proportion for each task; default=0.8.
        val_frac (float): valid set proportion for each task; default=0.0.
        train_tasks_frac (float): stored on the instance; it is not read
            anywhere inside this class; default=1.0.
        seed (int): ``random_state`` used for the per-task splits in
            ``process``; default=123.
        transform: transform for x.
        target_transform: transform for y.

    Raises:
        RuntimeError: if no ``task_*`` entry exists in the processed
            directory after the base class has been initialised.
    """
    def __init__(self,
                 root,
                 name,
                 s_frac=0.3,
                 tr_frac=0.8,
                 val_frac=0.0,
                 train_tasks_frac=1.0,
                 seed=123,
                 transform=None,
                 target_transform=None):
        # These attributes are set before the base initialiser runs.
        # NOTE(review): presumably the base class may call `process()`,
        # which reads them -- confirm against `LEAF.__init__`.
        self.s_frac = s_frac
        self.tr_frac = tr_frac
        self.val_frac = val_frac
        self.seed = seed
        self.train_tasks_frac = train_tasks_frac
        super().__init__(root, name, transform, target_transform)

        files = [
            f for f in os.listdir(self.processed_dir)
            if f.startswith('task_')
        ]
        if not files:
            raise RuntimeError(
                'Please delete processed folder and try again!')

        # Sort by the numeric suffix so that task_10 follows task_9.
        files.sort(key=lambda k: int(k[5:]))

        # NOTE(review): `self.data_dict` is not created in this class; it
        # is assumed to be provided by the base class -- confirm.
        for file in files:
            task_id = int(file[5:])
            task_dir = osp.join(self.processed_dir, file)

            train_data, train_targets = torch.load(
                osp.join(task_dir, 'train.pt'))
            test_data, test_targets = torch.load(
                osp.join(task_dir, 'test.pt'))
            self.data_dict[task_id] = {
                'train': (train_data, train_targets),
                'test': (test_data, test_targets)
            }

            # The validation split is optional.
            val_path = osp.join(task_dir, 'val.pt')
            if osp.exists(val_path):
                val_data, val_targets = torch.load(val_path)
                self.data_dict[task_id]['val'] = (val_data, val_targets)

    @property
    def raw_file_names(self):
        """Names of the raw archives fetched by ``download``."""
        names = [f'{self.name}_all_data.zip']
        return names

    def download(self):
        """Download every file in ``raw_file_names`` to ``self.raw_dir``."""
        # Download to `self.raw_dir`.
        url = 'https://federatedscope.oss-cn-beijing.aliyuncs.com'
        os.makedirs(self.raw_dir, exist_ok=True)
        for name in self.raw_file_names:
            download_url(f'{url}/{name}', self.raw_dir)

    def __getitem__(self, index):
        """
        Arguments:
            index (int): Index

        :returns:
            dict: {'train':[(image, target)],
                   'test':[(image, target)],
                   'val':[(image, target)]}
            where target is the target class.

        The stored tensors are never modified: ``target_transform`` is
        applied to a local value only, so fetching the same index again
        yields the same targets.
        """
        img_dict = {}
        data = self.data_dict[index]
        for key in data:
            img_dict[key] = []
            imgs, targets = data[key]
            for idx in range(targets.shape[0]):
                # Rebuild a PIL image from the stored sample.
                img = np.resize(imgs[idx].numpy().astype(np.uint8),
                                IMAGE_SIZE[self.name])
                img = Image.fromarray(img, mode=MODE[self.name])
                if self.transform is not None:
                    img = self.transform(img)

                # Do NOT write the result back into `targets`: that tensor
                # is the cached copy in `self.data_dict`, and writing to it
                # would re-apply the transform on every later access.
                target = targets[idx]
                if self.target_transform is not None:
                    target = self.target_transform(target)

                img_dict[key].append((img, target))

        return img_dict

    def process(self):
        """
        Convert the raw JSON files into per-task splits.

        A fraction ``s_frac`` of the ``.json`` files in
        ``<raw_dir>/all_data`` is selected. Every entry of each file's
        ``'user_data'`` mapping becomes one task: it is split into
        train/test (and val when ``val_frac > 0``) and handed to
        ``save_local_data`` with ``<processed_dir>/task_<idx>`` as the
        target directory.
        """
        raw_path = osp.join(self.raw_dir, "all_data")
        files = [f for f in os.listdir(raw_path) if f.endswith('.json')]

        n_tasks = math.ceil(len(files) * self.s_frac)
        # NOTE: this shuffle draws from the module-level `random` state,
        # not from `self.seed`; which files are kept therefore depends on
        # how the caller seeded `random`.
        random.shuffle(files)
        files = files[:n_tasks]

        print("Preprocess data (Please leave enough space)...")

        # Running task counter across all files.
        idx = 0
        for file in tqdm(files):
            with open(osp.join(raw_path, file), 'r') as f:
                raw_data = json.load(f)

            # Numpy to Tensor
            for v in raw_data['user_data'].values():
                data, targets = v['x'], v['y']

                if len(v['x']) > 2:
                    data = torch.tensor(np.stack(data))
                    targets = torch.LongTensor(np.stack(targets))
                else:
                    data = torch.tensor(data)
                    targets = torch.LongTensor(targets)

                train_data, test_data, train_targets, test_targets = \
                    train_test_split(
                        data,
                        targets,
                        train_size=self.tr_frac,
                        random_state=self.seed
                    )

                if self.val_frac > 0:
                    # `val_frac` is relative to the whole task, so rescale
                    # it to the held-out remainder being split here.
                    val_data, test_data, val_targets, test_targets = \
                        train_test_split(
                            test_data,
                            test_targets,
                            train_size=self.val_frac / (1.-self.tr_frac),
                            random_state=self.seed
                        )
                else:
                    val_data, val_targets = None, None

                save_path = osp.join(self.processed_dir, f"task_{idx}")
                os.makedirs(save_path, exist_ok=True)

                save_local_data(dir_path=save_path,
                                train_data=train_data,
                                train_targets=train_targets,
                                test_data=test_data,
                                test_targets=test_targets,
                                val_data=val_data,
                                val_targets=val_targets)
                idx += 1