FS-TFP/federatedscope/nlp/hetero_tasks/dataloader/dataloader.py

278 lines
12 KiB
Python

import os
import logging
import copy
from tqdm import tqdm
# Set TOKENIZERS_PARALLELISM to 'false' for this process at import time.
# NOTE(review): presumably this silences the HuggingFace tokenizers
# parallelism/fork warning when DataLoader workers are used -- confirm.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
def modified_cfg(cfg, cfg_client):
    """Adjust the global and per-client configs before data loading.

    Side effects on disk: ``<cfg.outdir>/temp`` is always created; the
    checkpoint directory is created when ``cfg.federate.save_to`` is set;
    ``<cfg.outdir>/config_client.yaml`` is written when ``cfg_client`` is
    given.

    Args:
        cfg: Global config. It is modified in place.
        cfg_client: Per-client config indexed by ``'client_group_<k>'``
            (1-based), or ``None``. It is modified in place: every
            ``'client_<id>'`` key is bound to its group's config object.

    Returns:
        The same ``(cfg, cfg_client)`` objects that were passed in.
    """
    # Scratch directory for outputs.
    os.makedirs(os.path.join(cfg.outdir, 'temp'), exist_ok=True)

    # Anchor the checkpoint location under the output directory.
    if cfg.federate.save_to:
        cfg.federate.save_to = os.path.join(cfg.outdir, cfg.federate.save_to)
        os.makedirs(cfg.federate.save_to, exist_ok=True)

    # Debug mode: shrink the run to one client per dataset and a couple of
    # rounds/steps, and switch off checkpoint loading/saving and caching.
    debug_mode = cfg.data.is_debug
    if debug_mode:
        cfg.federate.client_num = 6
        cfg.data.hetero_data_name = [
            'imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'
        ]
        cfg.data.num_of_client_for_data = [1, 1, 1, 1, 1, 1]
        cfg.federate.total_round_num = 2
        cfg.train.local_update_steps = 2
        # TODO
        cfg.federate.atc_load_from = ''
        cfg.federate.save_to = ''
        cfg.data.cache_dir = ''
        cfg.data.batch_size = 1

    # Expand per-group settings into one entry per client.
    stage = cfg.model.stage
    if stage == 'assign':
        per_client_tasks = []
        for gid, size in enumerate(cfg.data.num_of_client_for_data):
            per_client_tasks.extend([cfg.data.hetero_data_name[gid]] * size)
        cfg.model.downstream_tasks = per_client_tasks
    elif stage == 'contrast':
        per_client_topk = []
        if len(cfg.aggregator.num_agg_topk) > 0:
            for gid, size in enumerate(cfg.data.num_of_client_for_data):
                per_client_topk.extend(
                    [cfg.aggregator.num_agg_topk[gid]] * size)
        else:
            # No explicit top-k given: use the total client count.
            for size in cfg.data.num_of_client_for_data:
                per_client_topk.extend([cfg.federate.client_num] * size)
        cfg.aggregator.num_agg_topk = per_client_topk

    # Nothing client-specific to do without a client config.
    if cfg_client is None:
        return cfg, cfg_client

    # Write a snapshot of the client config (check functions cleared on the
    # copy only) by printing its dump into the YAML file.
    from contextlib import redirect_stdout
    snapshot_path = os.path.join(cfg.outdir, 'config_client.yaml')
    with open(snapshot_path, 'w') as outfile, redirect_stdout(outfile):
        tmp_cfg = copy.deepcopy(cfg_client)
        tmp_cfg.cfg_check_funcs = []
        print(tmp_cfg.dump())

    # Bind each client id to the config object of the group it belongs to;
    # client ids are 1-based and consecutive across groups.
    first_id = 1
    for gid, size in enumerate(cfg.data.num_of_client_for_data):
        group_cfg = cfg_client['client_group_{}'.format(gid + 1)]
        if debug_mode:
            if group_cfg.train.local_update_steps > 5:
                group_cfg.train.local_update_steps = 5
            group_cfg.data.batch_size = 1
        for offset in range(size):
            cfg_client['client_{}'.format(first_id + offset)] = group_cfg
        first_id += size

    return cfg, cfg_client
def load_heteroNLP_data(config, client_cfgs):
    """Build per-client dataloaders for the heterogeneous NLP tasks.

    Both configs are first passed through ``modified_cfg``; the adjusted
    versions are then used to load and process every client's
    train/val/test splits.

    Args:
        config: Global config. It is modified in place by ``modified_cfg``.
        client_cfgs: Per-client config, or ``None``. It is indexed by
            ``'client_<id>'`` whenever ``config.model.stage`` is not
            ``'assign'``, so it must not be ``None`` in that case.

    Returns:
        tuple: ``(datadict, modified_config)``. ``datadict`` maps the
        1-based client id to a dict of split entries. The keys are
        ``'val'``, ``'test'`` and either ``'train'`` or, when the stage is
        ``'contrast'``, ``'train_raw'`` and ``'train_contrast'``. Each
        entry is a dict with keys ``'dataloader'``, ``'encoded'`` and
        ``'examples'``.
    """
    # Heavy / project-local dependencies are imported lazily, on first use.
    from torch.utils.data import DataLoader
    from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
    from federatedscope.nlp.hetero_tasks.dataset.get_data import \
        HeteroNLPDataLoader, SynthDataProcessor
    from federatedscope.nlp.hetero_tasks.dataloader.datacollator import \
        DataCollator

    class HeteroNLPDataset(object):
        """Sequence-like container with one dataloader dict per client."""

        # Dataset names this loader knows how to process.
        ALL_DATA_NAME = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg']

        def __init__(self, config, client_cfgs):
            """Load the raw data and process it for every client.

            Dataset names not listed in ``ALL_DATA_NAME`` are skipped with
            a warning.
            """
            self.root = config.data.root
            self.num_of_clients = config.data.num_of_client_for_data
            self.data_name = list()
            # NOTE(review): a skipped name is dropped from ``data_name``
            # but its count stays in ``num_of_clients``, so the two lists
            # may no longer line up -- confirm how HeteroNLPDataLoader
            # pairs them.
            for each_data in config.data.hetero_data_name:
                if each_data not in self.ALL_DATA_NAME:
                    logger.warning(f'We have not provided {each_data} in '
                                   f'HeteroNLPDataset.')
                else:
                    self.data_name.append(each_data)
            data = HeteroNLPDataLoader(
                data_dir=self.root,
                data_name=self.data_name,
                num_of_clients=self.num_of_clients).get_data()
            self.processed_data = self._preprocess(config, client_cfgs, data)

        def __getitem__(self, idx):
            return self.processed_data[idx]

        def __len__(self):
            return len(self.processed_data)

        @staticmethod
        def _build_split_entry(split_data, shuffle, batch_size, config,
                               collate_fn):
            """Wrap one processed split into its dataloader entry.

            ``split_data`` is indexed positionally: element 0 is handed to
            ``DataLoader`` as the dataset, elements 1 and 2 are stored
            under ``'encoded'`` and ``'examples'``.
            """
            return {
                'dataloader': DataLoader(
                    dataset=split_data[0],
                    batch_size=batch_size,
                    shuffle=shuffle,
                    num_workers=config.data.num_workers,
                    collate_fn=collate_fn,
                    pin_memory=config.use_gpu),
                'encoded': split_data[1],
                'examples': split_data[2]
            }

        def _preprocess(self, config, client_cfgs, data):
            """Process all splits of all clients and build their loaders.

            Args:
                config: Global config.
                client_cfgs: Per-client config; only read when the stage
                    is not ``'assign'``.
                data: Raw data indexed as ``data[split][client_id - 1]``
                    for ``split`` in ``'train'``, ``'val'``, ``'test'``.

            Returns:
                list: One dataloader dict per client, ordered by client id.
            """
            use_pretrain_task = config.model.stage == 'assign'
            use_contrastive = config.model.stage == 'contrast'
            tokenizer = setup_tokenizer(config.model.model_type)
            # A custom collator is only used in the 'assign' stage;
            # otherwise DataLoader receives ``collate_fn=None``.
            data_collator = DataCollator(tokenizer=tokenizer) \
                if use_pretrain_task else None
            is_debug = config.data.is_debug  # load a subset of data

            processed_data = list()
            for client_id in tqdm(range(1, config.federate.client_num + 1)):
                # In the 'assign' stage the global config applies to every
                # client; otherwise each client has its own config.
                applied_cfg = config if use_pretrain_task \
                    else client_cfgs['client_{}'.format(client_id)]
                cur_task = applied_cfg.model.downstream_tasks[client_id - 1] \
                    if use_pretrain_task else applied_cfg.model.task

                train_data, val_data, test_data = [
                    self._process_data(data=data[split][client_id - 1],
                                       data_name=cur_task,
                                       split=split,
                                       tokenizer=tokenizer,
                                       model_type=config.model.model_type,
                                       cache_dir=config.data.cache_dir,
                                       cfg=applied_cfg.data,
                                       client_id=client_id,
                                       pretrain=use_pretrain_task,
                                       is_debug=is_debug)
                    for split in ['train', 'val', 'test']
                ]

                # Settings shared by every loader of this client.
                common = dict(batch_size=applied_cfg.data.batch_size,
                              config=config,
                              collate_fn=data_collator)

                dataloader = {}
                # Evaluation splits are never shuffled.
                dataloader['val'] = self._build_split_entry(
                    val_data, shuffle=False, **common)
                dataloader['test'] = self._build_split_entry(
                    test_data, shuffle=False, **common)
                if not use_contrastive:
                    dataloader['train'] = self._build_split_entry(
                        train_data, shuffle=config.data.shuffle, **common)
                else:
                    # Contrastive stage: the same training data is exposed
                    # twice, once with the configured shuffling and once
                    # in fixed order.
                    dataloader['train_raw'] = self._build_split_entry(
                        train_data, shuffle=config.data.shuffle, **common)
                    dataloader['train_contrast'] = self._build_split_entry(
                        train_data, shuffle=False, **common)
                processed_data.append(dataloader)

            if use_contrastive:
                logger.info(
                    'Preprocessing synthetic dataset for contrastive learning')
                synth_data_processor = SynthDataProcessor(
                    config, processed_data)
                synth_data_processor.save_data()

            return processed_data

        def _process_data(self, data, data_name, split, tokenizer, model_type,
                          cache_dir, cfg, client_id, pretrain, is_debug):
            """Dispatch one client's split to its task-specific processor.

            The processing module is imported lazily, so only the module
            of the selected task is loaded.

            Raises:
                NotImplementedError: If ``data_name`` is not one of the
                    supported dataset names.
            """
            if data_name == 'imdb':
                from federatedscope.nlp.hetero_tasks.dataset.imdb import \
                    process_imdb_dataset
                process_func = process_imdb_dataset
            elif data_name == 'agnews':
                from federatedscope.nlp.hetero_tasks.dataset.agnews import \
                    process_agnews_dataset
                process_func = process_agnews_dataset
            elif data_name == 'squad':
                from federatedscope.nlp.hetero_tasks.dataset.squad import \
                    process_squad_dataset
                process_func = process_squad_dataset
            elif data_name == 'newsqa':
                from federatedscope.nlp.hetero_tasks.dataset.newsqa import \
                    process_newsqa_dataset
                process_func = process_newsqa_dataset
            elif data_name == 'cnndm':
                from federatedscope.nlp.hetero_tasks.dataset.cnndm import \
                    process_cnndm_dataset
                process_func = process_cnndm_dataset
            elif data_name == 'msqg':
                from federatedscope.nlp.hetero_tasks.dataset.msqg import \
                    process_msqg_dataset
                process_func = process_msqg_dataset
            else:
                raise NotImplementedError(
                    f'Not process function is provided for {data_name}')

            # Length limits are optional in the data config; a missing one
            # is passed on as None.
            max_seq_len = getattr(cfg, 'max_seq_len', None)
            max_query_len = getattr(cfg, 'max_query_len', None)
            trunc_stride = getattr(cfg, 'trunc_stride', None)
            max_tgt_len = getattr(cfg, 'max_tgt_len', None)

            # ``max_src_len`` reuses ``max_seq_len`` and ``raw_cache_dir``
            # reuses ``cache_dir``.
            return process_func(data=data,
                                split=split,
                                tokenizer=tokenizer,
                                max_seq_len=max_seq_len,
                                max_query_len=max_query_len,
                                trunc_stride=trunc_stride,
                                max_src_len=max_seq_len,
                                max_tgt_len=max_tgt_len,
                                model_type=model_type,
                                cache_dir=cache_dir,
                                raw_cache_dir=cache_dir,
                                client_id=client_id,
                                pretrain=pretrain,
                                is_debug=is_debug)

    modified_config, modified_client_cfgs = modified_cfg(config, client_cfgs)
    dataset = HeteroNLPDataset(modified_config, modified_client_cfgs)

    # Convert to dict keyed by 1-based client id.
    datadict = dict(enumerate(dataset.processed_data, start=1))
    return datadict, modified_config