278 lines
12 KiB
Python
278 lines
12 KiB
Python
import os
import logging
import copy

from tqdm import tqdm

# Disable HuggingFace tokenizers' internal parallelism; avoids the
# fork-related warnings/deadlocks when DataLoader worker processes are used.
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Module-level logger, shared by the nested class below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def modified_cfg(cfg, cfg_client):
    """Finalize the server/client configs before data loading.

    Prepares output directories, shrinks the run when debug mode is on,
    derives stage-specific fields (``model.downstream_tasks`` for the
    'assign' stage, per-client ``aggregator.num_agg_topk`` for the
    'contrast' stage), dumps the client config to ``config_client.yaml``
    and expands each per-group client config into per-client entries.

    Args:
        cfg: global config object; mutated in place.
        cfg_client: per-group client config object, or ``None``.

    Returns:
        The tuple ``(cfg, cfg_client)`` — the same (mutated) objects.
    """
    # A scratch directory under the output dir is always created.
    os.makedirs(os.path.join(cfg.outdir, 'temp'), exist_ok=True)

    # Re-root the checkpoint dir under outdir and make sure it exists.
    if cfg.federate.save_to:
        cfg.federate.save_to = os.path.join(cfg.outdir, cfg.federate.save_to)
        os.makedirs(cfg.federate.save_to, exist_ok=True)

    # Debug mode: one client per task and minimal round/step/batch sizes
    # so a full pass over the pipeline finishes quickly.
    if cfg.data.is_debug:
        cfg.federate.client_num = 6
        cfg.data.hetero_data_name = [
            'imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg'
        ]
        cfg.data.num_of_client_for_data = [1, 1, 1, 1, 1, 1]
        cfg.federate.total_round_num = 2
        cfg.train.local_update_steps = 2
        # TODO
        cfg.federate.atc_load_from = ''
        cfg.federate.save_to = ''
        cfg.data.cache_dir = ''
        cfg.data.batch_size = 1

    if cfg.model.stage == 'assign':
        # One task name per client: group g's dataset name repeated once
        # for each client assigned to that group.
        cfg.model.downstream_tasks = [
            cfg.data.hetero_data_name[group_id]
            for group_id, num_clients in enumerate(
                cfg.data.num_of_client_for_data)
            for _ in range(num_clients)
        ]
    elif cfg.model.stage == 'contrast':
        # Per-client aggregation top-k: expand the per-group values when
        # given, otherwise fall back to "aggregate over all clients".
        if len(cfg.aggregator.num_agg_topk) > 0:
            expanded_topk = [
                cfg.aggregator.num_agg_topk[group_id]
                for group_id, num_clients in enumerate(
                    cfg.data.num_of_client_for_data)
                for _ in range(num_clients)
            ]
        else:
            expanded_topk = [
                cfg.federate.client_num
                for num_clients in cfg.data.num_of_client_for_data
                for _ in range(num_clients)
            ]
        cfg.aggregator.num_agg_topk = expanded_topk

    # client_config
    if cfg_client is not None:
        # Persist the (sanitized) client config for reproducibility.
        with open(os.path.join(cfg.outdir, 'config_client.yaml'),
                  'w') as outfile:
            from contextlib import redirect_stdout
            with redirect_stdout(outfile):
                tmp_cfg = copy.deepcopy(cfg_client)
                tmp_cfg.cfg_check_funcs = []
                print(tmp_cfg.dump())

        # Expand each 'client_group_<g>' entry into one 'client_<i>' entry
        # per client of that group. NOTE: the group config object itself is
        # shared (not copied) by all clients of the group.
        next_client_id = 1
        for group_id, num_clients in enumerate(
                cfg.data.num_of_client_for_data):
            group_cfg = cfg_client['client_group_{}'.format(group_id + 1)]
            if cfg.data.is_debug:
                if group_cfg.train.local_update_steps > 5:
                    group_cfg.train.local_update_steps = 5
                group_cfg.data.batch_size = 1
            for client_id in range(next_client_id,
                                   next_client_id + num_clients):
                cfg_client['client_{}'.format(client_id)] = group_cfg
            next_client_id += num_clients

    return cfg, cfg_client
|
|
|
|
|
|
def load_heteroNLP_data(config, client_cfgs):
    """Build the federated hetero-task NLP dataset.

    Finalizes the configs via ``modified_cfg``, then loads, tokenizes and
    wraps each client's data splits in ``DataLoader``s.

    Args:
        config: global config object.
        client_cfgs: per-group/per-client config object, or ``None``.

    Returns:
        Tuple ``(datadict, modified_config)`` where ``datadict`` maps the
        1-based client id to that client's dict of split dataloaders.
    """
    from torch.utils.data import DataLoader
    from federatedscope.nlp.hetero_tasks.dataset.utils import setup_tokenizer
    from federatedscope.nlp.hetero_tasks.dataset.get_data import \
        HeteroNLPDataLoader, SynthDataProcessor
    from federatedscope.nlp.hetero_tasks.dataloader.datacollator import \
        DataCollator

    class HeteroNLPDataset(object):
        # Closed set of supported dataset/task names; anything else in
        # ``config.data.hetero_data_name`` is dropped with a warning.
        ALL_DATA_NAME = ['imdb', 'agnews', 'squad', 'newsqa', 'cnndm', 'msqg']

        def __init__(self, config, client_cfgs):
            self.root = config.data.root
            self.num_of_clients = config.data.num_of_client_for_data
            # Keep only the dataset names this class knows how to process.
            self.data_name = list()
            for each_data in config.data.hetero_data_name:
                if each_data not in self.ALL_DATA_NAME:
                    logger.warning(f'We have not provided {each_data} in '
                                   f'HeteroNLPDataset.')
                else:
                    self.data_name.append(each_data)

            # Raw train/val/test splits for every client of every dataset.
            data = HeteroNLPDataLoader(
                data_dir=self.root,
                data_name=self.data_name,
                num_of_clients=self.num_of_clients).get_data()

            self.processed_data = self._preprocess(config, client_cfgs, data)

        def __getitem__(self, idx):
            # ``idx`` is 0-based (client_id - 1).
            return self.processed_data[idx]

        def __len__(self):
            return len(self.processed_data)

        def _preprocess(self, config, client_cfgs, data):
            """Tokenize each client's splits and wrap them in DataLoaders."""
            use_pretrain_task = config.model.stage == 'assign'
            use_contrastive = config.model.stage == 'contrast'
            tokenizer = setup_tokenizer(config.model.model_type)
            # A collator is only attached during the pretrain ('assign')
            # stage; otherwise the DataLoaders use the default collation.
            data_collator = DataCollator(tokenizer=tokenizer) \
                if use_pretrain_task else None
            is_debug = config.data.is_debug  # load a subset of data

            processed_data = list()
            for client_id in tqdm(range(1, config.federate.client_num + 1)):
                # Pretrain stage: all clients share the global config;
                # otherwise each client uses its own entry in client_cfgs.
                applied_cfg = config if use_pretrain_task \
                    else client_cfgs['client_{}'.format(client_id)]

                cur_task = applied_cfg.model.downstream_tasks[client_id - 1] \
                    if use_pretrain_task else applied_cfg.model.task

                # NOTE(review): assumes ``data[split]`` is a list indexed by
                # client_id - 1 — confirm against HeteroNLPDataLoader.
                # Each *_data is the (dataset, encoded, examples) triple
                # returned by ``_process_data``.
                train_data, val_data, test_data = [
                    self._process_data(data=data[split][client_id - 1],
                                       data_name=cur_task,
                                       split=split,
                                       tokenizer=tokenizer,
                                       model_type=config.model.model_type,
                                       cache_dir=config.data.cache_dir,
                                       cfg=applied_cfg.data,
                                       client_id=client_id,
                                       pretrain=use_pretrain_task,
                                       is_debug=is_debug)
                    for split in ['train', 'val', 'test']
                ]

                # Each split entry bundles the DataLoader with the encoded
                # features and the original examples.
                dataloader = {}
                dataloader['val'] = {
                    'dataloader': DataLoader(
                        dataset=val_data[0],
                        batch_size=applied_cfg.data.batch_size,
                        shuffle=False,
                        num_workers=config.data.num_workers,
                        collate_fn=data_collator,
                        pin_memory=config.use_gpu),
                    'encoded': val_data[1],
                    'examples': val_data[2]
                }
                dataloader['test'] = {
                    'dataloader': DataLoader(
                        dataset=test_data[0],
                        batch_size=applied_cfg.data.batch_size,
                        shuffle=False,
                        num_workers=config.data.num_workers,
                        collate_fn=data_collator,
                        pin_memory=config.use_gpu),
                    'encoded': test_data[1],
                    'examples': test_data[2]
                }

                if not use_contrastive:
                    dataloader['train'] = {
                        'dataloader': DataLoader(
                            dataset=train_data[0],
                            batch_size=applied_cfg.data.batch_size,
                            shuffle=config.data.shuffle,
                            num_workers=config.data.num_workers,
                            collate_fn=data_collator,
                            pin_memory=config.use_gpu),
                        'encoded': train_data[1],
                        'examples': train_data[2]
                    }
                else:
                    # Contrast stage exposes two views of the SAME training
                    # set: 'train_raw' (shuffled per config) and
                    # 'train_contrast' (fixed order).
                    dataloader['train_raw'] = {
                        'dataloader': DataLoader(
                            dataset=train_data[0],
                            batch_size=applied_cfg.data.batch_size,
                            shuffle=config.data.shuffle,
                            num_workers=config.data.num_workers,
                            collate_fn=data_collator,
                            pin_memory=config.use_gpu),
                        'encoded': train_data[1],
                        'examples': train_data[2]
                    }
                    dataloader['train_contrast'] = {
                        'dataloader': DataLoader(
                            dataset=train_data[0],
                            batch_size=applied_cfg.data.batch_size,
                            shuffle=False,
                            num_workers=config.data.num_workers,
                            collate_fn=data_collator,
                            pin_memory=config.use_gpu),
                        'encoded': train_data[1],
                        'examples': train_data[2]
                    }
                processed_data.append(dataloader)

            if use_contrastive:
                # Build and persist the synthetic dataset used by
                # contrastive learning (see SynthDataProcessor.save_data).
                logger.info(
                    'Preprocessing synthetic dataset for contrastive learning')
                synth_data_processor = SynthDataProcessor(
                    config, processed_data)
                synth_data_processor.save_data()

            return processed_data

        def _process_data(self, data, data_name, split, tokenizer, model_type,
                          cache_dir, cfg, client_id, pretrain, is_debug):
            """Dispatch to the dataset-specific processing function."""
            if data_name == 'imdb':
                from federatedscope.nlp.hetero_tasks.dataset.imdb import \
                    process_imdb_dataset
                process_func = process_imdb_dataset
            elif data_name == 'agnews':
                from federatedscope.nlp.hetero_tasks.dataset.agnews import \
                    process_agnews_dataset
                process_func = process_agnews_dataset
            elif data_name == 'squad':
                from federatedscope.nlp.hetero_tasks.dataset.squad import \
                    process_squad_dataset
                process_func = process_squad_dataset
            elif data_name == 'newsqa':
                from federatedscope.nlp.hetero_tasks.dataset.newsqa import \
                    process_newsqa_dataset
                process_func = process_newsqa_dataset
            elif data_name == 'cnndm':
                from federatedscope.nlp.hetero_tasks.dataset.cnndm import \
                    process_cnndm_dataset
                process_func = process_cnndm_dataset
            elif data_name == 'msqg':
                from federatedscope.nlp.hetero_tasks.dataset.msqg import \
                    process_msqg_dataset
                process_func = process_msqg_dataset
            else:
                raise NotImplementedError(
                    f'Not process function is provided for {data_name}')

            # Optional length limits; default to None when the client config
            # does not define them (not every task uses all of them).
            max_seq_len = getattr(cfg, 'max_seq_len', None)
            max_query_len = getattr(cfg, 'max_query_len', None)
            trunc_stride = getattr(cfg, 'trunc_stride', None)
            max_tgt_len = getattr(cfg, 'max_tgt_len', None)

            # max_src_len deliberately reuses max_seq_len; cache_dir doubles
            # as raw_cache_dir.
            return process_func(data=data,
                                split=split,
                                tokenizer=tokenizer,
                                max_seq_len=max_seq_len,
                                max_query_len=max_query_len,
                                trunc_stride=trunc_stride,
                                max_src_len=max_seq_len,
                                max_tgt_len=max_tgt_len,
                                model_type=model_type,
                                cache_dir=cache_dir,
                                raw_cache_dir=cache_dir,
                                client_id=client_id,
                                pretrain=pretrain,
                                is_debug=is_debug)

    modified_config, modified_client_cfgs = modified_cfg(config, client_cfgs)
    dataset = HeteroNLPDataset(modified_config, modified_client_cfgs)
    # Convert to dict keyed by the 1-based client id.
    datadict = {
        client_id + 1: dataset[client_id]
        for client_id in range(len(dataset))
    }

    return datadict, modified_config
|