FS-TFP/federatedscope/nlp/hetero_tasks/dataset/imdb.py

import os
import os.path as osp
import logging
import torch
from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)


def get_imdb_examples(data, is_debug=False):
    # Flatten the raw IMDB records into (review text, sentiment label)
    # pairs; in debug mode, keep only the first NUM_DEBUG records.
    if is_debug:
        data = data[:NUM_DEBUG]
    examples = []
    for ex in data:
        examples.append((ex['text'], ex['label']))
    return examples
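

# A minimal illustration of get_imdb_examples (hypothetical data, not part of
# the original module): an input like [{'text': 'A wonderful film.',
# 'label': 1}] maps to [('A wonderful film.', 1)].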


def process_imdb_dataset(data,
                         split,
                         tokenizer,
                         max_seq_len,
                         cache_dir='',
                         client_id=None,
                         pretrain=False,
                         is_debug=False,
                         **kwargs):
    # Pre-training uses a separate pipeline (sentence splitting plus
    # BOS/EOS marking), implemented below.
    if pretrain:
        return process_imdb_dataset_for_pretrain(data, split, tokenizer,
                                                 max_seq_len, cache_dir,
                                                 client_id, is_debug)

    # Tokenized inputs are cached per client and per split so that repeated
    # runs can skip tokenization.
    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_imdb_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    labels = [ex[1] for ex in examples]
    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'token_type_ids': encoded_inputs.token_type_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'labels': torch.LongTensor(labels),
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
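

# A minimal usage sketch (illustrative, not part of the original module). It
# assumes a HuggingFace tokenizer and a toy in-memory list standing in for
# the raw IMDB split; 'bert-base-uncased' and the sample records below are
# assumptions, not values used by FederatedScope itself.
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   raw_data = [{'text': 'A wonderful film.', 'label': 1},
#               {'text': 'Dull and overlong.', 'label': 0}]
#   dataset, encoded_inputs, examples = process_imdb_dataset(
#       raw_data, split='train', tokenizer=tokenizer, max_seq_len=128,
#       cache_dir='', client_id=1)
#   # encoded_inputs.input_ids has shape (2, 128); dataset wraps token_ids,
#   # token_type_ids, attention_mask, labels, and example_indices.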


def process_imdb_dataset_for_pretrain(data,
                                      split,
                                      tokenizer,
                                      max_seq_len,
                                      cache_dir='',
                                      client_id=None,
                                      is_debug=False):
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_imdb_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        # Split each review into sentences marked with the tokenizer's
        # eoq_token before encoding.
        texts = split_sent(texts, eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        # num_non_padding[i] is the number of real (non-pad) tokens in row i,
        # i.e. the index of the first padding position.
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            # Overwrite the first token with BOS and the last real token
            # with EOS.
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id
        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
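

# Worked illustration of the BOS/EOS rewrite above (hypothetical token values,
# not part of the original module): for a padded row [t1, t2, t3, pad, pad],
# num_non_padding == 3, so position 0 is overwritten with bos_token_id and
# position 2 (the last real token) with eos_token_id, giving
# [bos, t2, eos, pad, pad].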