import os
import os.path as osp
import logging

import torch

from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)


def get_imdb_examples(data, is_debug=False):
    """Extract (text, label) pairs from raw IMDB records, truncating to
    NUM_DEBUG samples in debug mode."""
    if is_debug:
        data = data[:NUM_DEBUG]
    examples = []
    for ex in data:
        examples.append((ex['text'], ex['label']))
    return examples


def process_imdb_dataset(data,
                         split,
                         tokenizer,
                         max_seq_len,
                         cache_dir='',
                         client_id=None,
                         pretrain=False,
                         is_debug=False,
                         **kwargs):
    """Tokenize IMDB examples for sentiment-classification fine-tuning.

    Encoded results are cached per client under
    ``<cache_dir>/train/<client_id>/<split>.pt`` and reloaded on later calls.
    Returns a ``(dataset, encoded_inputs, examples)`` triple.
    """
    if pretrain:
        return process_imdb_dataset_for_pretrain(data, split, tokenizer,
                                                 max_seq_len, cache_dir,
                                                 client_id, is_debug)

    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_imdb_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    labels = [ex[1] for ex in examples]
    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'token_type_ids': encoded_inputs.token_type_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'labels': torch.LongTensor(labels),
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples


def process_imdb_dataset_for_pretrain(data,
                                      split,
                                      tokenizer,
                                      max_seq_len,
                                      cache_dir='',
                                      client_id=None,
                                      is_debug=False):
    """Tokenize IMDB examples for pre-training.

    Each review is first split into sentences joined by the tokenizer's
    end-of-query token; the first token of every encoded sequence is then
    overwritten with BOS and the last non-padding token with EOS. Caching
    mirrors ``process_imdb_dataset`` but lives under the ``pretrain``
    subdirectory.
    """
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_imdb_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        texts = split_sent(texts, eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        # Mark sequence boundaries: for each row, `num_non_padding` is the
        # index of the first padding token, so position 0 gets BOS and
        # position `pad_idx - 1` (the last real token) gets EOS.
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id
        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
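

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). In FederatedScope the tokenizer and the
# per-client IMDB split are wired in by the framework's data loaders; the
# `bert-base-uncased` checkpoint and the two toy records below are
# assumptions made purely for this example.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    raw_data = [
        {'text': 'A gripping, well-acted film.', 'label': 1},  # hypothetical
        {'text': 'Dull from start to finish.', 'label': 0},  # hypothetical
    ]
    dataset, encoded_inputs, examples = process_imdb_dataset(
        raw_data,
        split='train',
        tokenizer=tokenizer,
        max_seq_len=128,
        cache_dir='',  # empty cache_dir: nothing is written to disk
        client_id=1)
    # Every review is padded/truncated to max_seq_len tokens.
    assert encoded_inputs.input_ids.shape == (2, 128)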