import os
import os.path as osp
import logging

import torch
import numpy as np

from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)
def get_msqg_examples(data, is_debug=False):
    """Split raw MSQG examples into parallel source and target lists.

    Args:
        data: sequence of raw examples, each a dict carrying the text
            fields 'src' and 'tgt'.
        is_debug: when True, truncate to the first NUM_DEBUG examples
            to speed up debugging runs.

    Returns:
        A pair ``(src_examples, tgt_examples)`` of equal-length lists.
    """
    if is_debug:
        data = data[:NUM_DEBUG]
    sources = [example['src'] for example in data]
    targets = [example['tgt'] for example in data]
    return sources, targets
def process_msqg_dataset(data,
                         split,
                         tokenizer,
                         max_src_len,
                         max_tgt_len,
                         raw_cache_dir='',
                         client_id=None,
                         pretrain=False,
                         is_debug=False,
                         **kwargs):
    """Encode the MSQG data of one client/split into fixed-length tensors.

    Args:
        data: iterable of raw examples, each a dict with 'src' and 'tgt'
            text fields (see ``get_msqg_examples``).
        split: split name, used only to build the cache sub-directory path.
        tokenizer: HuggingFace-style tokenizer; must expose ``eoq_token``,
            ``pad_token_id``, ``bos_token_id`` and ``eos_token_id``.
        max_src_len: padded/truncated length of the encoded source side.
        max_tgt_len: padded/truncated length of the encoded target side.
        raw_cache_dir: root directory of the on-disk memmap cache; if
            empty, nothing is written to disk and the encodings stay in
            memory.
        client_id: federated client id, part of the cache path.
        pretrain: if True, delegate to the source-only pretraining variant
            and return its result unchanged.
        is_debug: if True, keep only the first NUM_DEBUG raw examples.
        **kwargs: ignored; accepted for interface compatibility with
            sibling dataset processors.

    Returns:
        Tuple ``(dataset, None, None)`` where ``dataset`` is a
        ``DatasetDict`` with keys 'token_ids', 'token_type_ids',
        'attention_mask', 'labels' and 'example_indices'.
    """
    if pretrain:
        # Pretraining only needs the source side; hand off entirely.
        return process_msqg_dataset_for_pretrain(data, split, tokenizer,
                                                 max_src_len, raw_cache_dir,
                                                 client_id, is_debug)

    cache_dir = osp.join(raw_cache_dir, 'train', str(client_id), split)
    src_examples, tgt_examples = get_msqg_examples(data, is_debug)
    if osp.exists(cache_dir):
        # Cache hit: map the previously written arrays directly from disk.
        # NOTE(review): the shapes here must match those used when the
        # cache was written (same example count, max_src_len and
        # max_tgt_len); a stale cache of different size would be read
        # incorrectly -- confirm cache invalidation happens upstream.
        logger.info('Loading cache file from \'{}\''.format(cache_dir))
        token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'),
                              shape=(len(src_examples), max_src_len),
                              mode='r',
                              dtype=np.int64)
        token_type_ids = np.memmap(filename=osp.join(cache_dir,
                                                     'token_type_ids.memmap'),
                                   shape=(len(src_examples), max_src_len),
                                   mode='r',
                                   dtype=np.int64)
        attention_mask = np.memmap(filename=osp.join(cache_dir,
                                                     'attention_mask.memmap'),
                                   shape=(len(src_examples), max_src_len),
                                   mode='r',
                                   dtype=np.int64)
        labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'),
                           shape=(len(src_examples), max_tgt_len),
                           mode='r',
                           dtype=np.int64)

        # Wrap the (read-only) memmaps as torch tensors without copying.
        # NOTE(review): torch.from_numpy on a non-writable array emits a
        # UserWarning; harmless here as long as callers never write to
        # these tensors in place.
        token_ids = torch.from_numpy(token_ids)
        token_type_ids = torch.from_numpy(token_type_ids)
        attention_mask = torch.from_numpy(attention_mask)
        labels = torch.from_numpy(labels)

    else:
        # Cache miss: tokenize from scratch.
        src_encoded = tokenizer(src_examples,
                                padding='max_length',
                                truncation=True,
                                max_length=max_src_len,
                                return_tensors='pt')
        # Join target sentences with the end-of-question token before
        # encoding (tokenize=False keeps them as plain strings).
        tgt_examples = split_sent(tgt_examples,
                                  eoq=tokenizer.eoq_token,
                                  tokenize=False)
        tgt_encoded = tokenizer(tgt_examples,
                                padding='max_length',
                                truncation=True,
                                max_length=max_tgt_len,
                                return_tensors='pt')
        # Per example, count non-padding tokens, then stamp BOS at
        # position 0 and EOS over the last non-padding token, in place.
        num_non_padding = (tgt_encoded.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            tgt_encoded.input_ids[i, 0] = tokenizer.bos_token_id
            tgt_encoded.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if raw_cache_dir:
            # Persist the encodings as int64 memmaps so later runs can
            # take the fast path above.
            logger.info('Saving cache file to \'{}\''.format(cache_dir))
            os.makedirs(cache_dir, exist_ok=True)
            token_ids = np.memmap(filename=osp.join(cache_dir,
                                                    'token_ids.memmap'),
                                  shape=(len(src_examples), max_src_len),
                                  mode='w+',
                                  dtype=np.int64)
            token_type_ids = np.memmap(filename=osp.join(
                cache_dir, 'token_type_ids.memmap'),
                                       shape=(len(src_examples), max_src_len),
                                       mode='w+',
                                       dtype=np.int64)
            attention_mask = np.memmap(filename=osp.join(
                cache_dir, 'attention_mask.memmap'),
                                       shape=(len(src_examples), max_src_len),
                                       mode='w+',
                                       dtype=np.int64)
            labels = np.memmap(filename=osp.join(cache_dir, 'labels.memmap'),
                               shape=(len(src_examples), max_tgt_len),
                               mode='w+',
                               dtype=np.int64)

            # Copy row by row from the tokenizer's tensors into the
            # writable memmaps.
            for i in range(len(src_examples)):
                token_ids[i] = src_encoded.input_ids[i]
                token_type_ids[i] = src_encoded.token_type_ids[i]
                attention_mask[i] = src_encoded.attention_mask[i]
                labels[i] = tgt_encoded.input_ids[i]

            token_ids = torch.from_numpy(token_ids)
            token_type_ids = torch.from_numpy(token_type_ids)
            attention_mask = torch.from_numpy(attention_mask)
            labels = torch.from_numpy(labels)

        else:
            # No caching requested: use the tokenizer tensors directly.
            token_ids = src_encoded.input_ids
            token_type_ids = src_encoded.token_type_ids
            attention_mask = src_encoded.attention_mask
            labels = tgt_encoded.input_ids

    # Each example keeps its positional index so it can be traced back.
    example_indices = torch.arange(token_ids.size(0), dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': token_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'labels': labels,
        'example_indices': example_indices
    })
    # The two trailing Nones keep the return arity consistent with the
    # other dataset processors in this package.
    return dataset, None, None
def process_msqg_dataset_for_pretrain(data,
                                      split,
                                      tokenizer,
                                      max_src_len,
                                      raw_cache_dir='',
                                      client_id=None,
                                      is_debug=False):
    """Encode only the source side of the MSQG data for pretraining.

    Mirrors ``process_msqg_dataset`` but skips the target encoding: the
    cache lives under 'pretrain/<client_id>/<split>' and holds just
    'token_ids' and 'attention_mask'.

    Args:
        data: iterable of raw examples, each a dict with 'src' and 'tgt'
            text fields (the target side is unused here).
        split: split name, used only to build the cache sub-directory path.
        tokenizer: HuggingFace-style tokenizer; must expose ``eoq_token``,
            ``pad_token_id``, ``bos_token_id`` and ``eos_token_id``.
        max_src_len: padded/truncated length of the encoded source side.
        raw_cache_dir: root directory of the on-disk memmap cache; if
            empty, nothing is written to disk.
        client_id: federated client id, part of the cache path.
        is_debug: if True, keep only the first NUM_DEBUG raw examples.

    Returns:
        Tuple ``(dataset, None, None)`` where ``dataset`` is a
        ``DatasetDict`` with keys 'token_ids', 'attention_mask' and
        'example_indices'.
    """
    cache_dir = osp.join(raw_cache_dir, 'pretrain', str(client_id), split)
    # Target examples are fetched for interface symmetry but unused here.
    src_examples, tgt_examples = get_msqg_examples(data, is_debug)
    if osp.exists(cache_dir):
        # Cache hit: map the previously written arrays directly from disk.
        # NOTE(review): shapes must match those used when the cache was
        # written -- confirm cache invalidation happens upstream.
        logger.info('Loading cache file from \'{}\''.format(cache_dir))
        token_ids = np.memmap(filename=osp.join(cache_dir, 'token_ids.memmap'),
                              shape=(len(src_examples), max_src_len),
                              mode='r',
                              dtype=np.int64)
        attention_mask = np.memmap(filename=osp.join(cache_dir,
                                                     'attention_mask.memmap'),
                                   shape=(len(src_examples), max_src_len),
                                   mode='r',
                                   dtype=np.int64)
        # Zero-copy wrap; the underlying memmaps are read-only.
        token_ids = torch.from_numpy(token_ids)
        attention_mask = torch.from_numpy(attention_mask)

    else:
        # Cache miss: join source sentences with the end-of-question token
        # (tokenize=False keeps them as plain strings), then encode.
        src_examples = split_sent(src_examples,
                                  eoq=tokenizer.eoq_token,
                                  tokenize=False)
        src_encoded = tokenizer(src_examples,
                                padding='max_length',
                                truncation=True,
                                max_length=max_src_len,
                                return_tensors='pt')
        # Per example, count non-padding tokens, then stamp BOS at
        # position 0 and EOS over the last non-padding token, in place.
        num_non_padding = (src_encoded.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            src_encoded.input_ids[i, 0] = tokenizer.bos_token_id
            src_encoded.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if raw_cache_dir:
            # Persist the encodings as int64 memmaps so later runs can
            # take the fast path above.
            logger.info('Saving cache file to \'{}\''.format(cache_dir))
            os.makedirs(cache_dir, exist_ok=True)
            token_ids = np.memmap(filename=osp.join(cache_dir,
                                                    'token_ids.memmap'),
                                  shape=(len(src_examples), max_src_len),
                                  mode='w+',
                                  dtype=np.int64)
            attention_mask = np.memmap(filename=osp.join(
                cache_dir, 'attention_mask.memmap'),
                                       shape=(len(src_examples), max_src_len),
                                       mode='w+',
                                       dtype=np.int64)

            # Copy row by row from the tokenizer's tensors into the
            # writable memmaps.
            for i in range(len(src_examples)):
                token_ids[i] = src_encoded.input_ids[i]
                attention_mask[i] = src_encoded.attention_mask[i]

            token_ids = torch.from_numpy(token_ids)
            attention_mask = torch.from_numpy(attention_mask)
        else:
            # No caching requested: use the tokenizer tensors directly.
            token_ids = src_encoded.input_ids
            attention_mask = src_encoded.attention_mask

    # Each example keeps its positional index so it can be traced back.
    example_indices = torch.arange(token_ids.size(0), dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': token_ids,
        'attention_mask': attention_mask,
        'example_indices': example_indices
    })
    # The two trailing Nones keep the return arity consistent with the
    # other dataset processors in this package.
    return dataset, None, None