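"""Preprocessing for the AG News dataset in federatedscope.nlp.hetero_tasks.

Raw records are converted into (text, label) examples, tokenized, and cached
per client, both for downstream classification and for pretraining.
"""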
import os
import os.path as osp
import logging
import torch

from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)


def get_agnews_examples(data, is_debug=False):
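    """Convert raw AG News records into plain (text, label) tuples.

    When ``is_debug`` is set, only the first ``NUM_DEBUG`` records are kept.
    """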
    if is_debug:
        data = data[:NUM_DEBUG]
    examples = []
    for ex in data:
        examples.append((ex['text'], ex['label']))
    return examples


def process_agnews_dataset(data,
                           split,
                           tokenizer,
                           max_seq_len,
                           cache_dir='',
                           client_id=None,
                           pretrain=False,
                           is_debug=False,
                           **kwargs):
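    """Tokenize AG News data for the downstream classification task.

    When ``cache_dir`` is given, the encoded inputs are cached under
    ``<cache_dir>/train/<client_id>/<split>.pt`` and reloaded on later calls.
    If ``pretrain`` is True, processing is delegated to
    ``process_agnews_dataset_for_pretrain``.

    Returns a tuple of (dataset, encoded_inputs, examples).
    """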
    if pretrain:
        return process_agnews_dataset_for_pretrain(data, split, tokenizer,
                                                   max_seq_len, cache_dir,
                                                   client_id, is_debug)

    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_agnews_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    labels = [ex[1] for ex in examples]
    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'token_type_ids': encoded_inputs.token_type_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'labels': torch.LongTensor(labels),
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples


def process_agnews_dataset_for_pretrain(data,
                                        split,
                                        tokenizer,
                                        max_seq_len,
                                        cache_dir='',
                                        client_id=None,
                                        is_debug=False):
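    """Tokenize AG News texts for pretraining (no labels are attached).

    Texts are first split into sentences via ``split_sent`` using the
    tokenizer's ``eoq_token``; after encoding, the first and last non-padding
    positions of each sequence are overwritten with the tokenizer's BOS and
    EOS token ids. When ``cache_dir`` is given, results are cached under
    ``<cache_dir>/pretrain/<client_id>/<split>.pt``.

    Returns a tuple of (dataset, encoded_inputs, examples).
    """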
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')

    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_agnews_examples(data, is_debug)
        texts = [ex[0] for ex in examples]
        texts = split_sent(texts, eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
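

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module). It assumes a
    # HuggingFace tokenizer and a couple of toy AG News-style records purely
    # for illustration; in FederatedScope these functions are normally driven
    # by the framework's data pipeline. The pretrain path additionally needs a
    # tokenizer that exposes eoq_token, bos_token_id and eos_token_id.
    from transformers import AutoTokenizer  # assumed dependency for the demo

    toy_data = [
        {'text': 'Wall St. bears claw back into the black.', 'label': 2},
        {'text': 'New chip promises faster mobile gaming.', 'label': 3},
    ]
    demo_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    dataset, encoded_inputs, examples = process_agnews_dataset(
        toy_data,
        split='train',
        tokenizer=demo_tokenizer,
        max_seq_len=32,
        cache_dir='',  # an empty cache_dir skips writing the cache file
        client_id=1)
    print(encoded_inputs.input_ids.shape, len(examples))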