import os
import json
import logging

import numpy as np

try:
    import torch
    from torch.utils.data import Dataset
except ImportError:
    # Allow importing this module without torch installed; DatasetDict
    # below then degrades to a plain object subclass.
    torch = None
    Dataset = object

NUM_DEBUG = 20

# Sentinel values; setup_tokenizer() overwrites these with real vocab ids.
BOS_TOKEN_ID = -1
EOS_TOKEN_ID = -1
EOQ_TOKEN_ID = -1
PAD_TOKEN_ID = -1

logger = logging.getLogger(__name__)


def split_sent(examples, eoq='[unused2]', tokenize=True):
    """Mark sentence boundaries in each example with the eoq token.

    If ``tokenize`` is True, sentence boundaries are found with NLTK's
    punkt tokenizer; otherwise existing '[SEP]' markers are reused.
    """
    import nltk
    from nltk.tokenize import sent_tokenize

    # Fetch the punkt sentence tokenizer on first use.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    new_examples = []
    for e in examples:
        if tokenize:
            e = f' {eoq} '.join(sent_tokenize(e))
        else:
            e = e.replace('[SEP]', eoq)
        new_examples.append(e)
    return new_examples


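# Usage sketch for split_sent (the input strings are made up for
# illustration):
#
#   split_sent(['First sentence. Second one.'])
#   # -> ['First sentence. [unused2] Second one.']
#
#   split_sent(['First sentence. [SEP] Second one.'], tokenize=False)
#   # -> ['First sentence. [unused2] Second one.']

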
class DatasetDict(Dataset):
    """A Dataset over a dict of tensors that share their first dimension."""

    def __init__(self, inputs):
        super().__init__()
        assert all(
            next(iter(inputs.values())).size(0) == v.size(0)
            for v in inputs.values()), "Size mismatch between tensors"
        self.inputs = inputs

    def __getitem__(self, index):
        # Index every tensor in the dict at the same position.
        return {k: v[index] for k, v in self.inputs.items()}

    def __len__(self):
        return next(iter(self.inputs.values())).size(0)


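# Usage sketch for DatasetDict (tensor names and sizes are illustrative):
#
#   ds = DatasetDict({'input_ids': torch.zeros(8, 128, dtype=torch.long),
#                     'labels': torch.zeros(8, dtype=torch.long)})
#   len(ds)  # -> 8
#   ds[3]    # -> {'input_ids': <128-dim tensor>, 'labels': <scalar tensor>}

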
def setup_tokenizer(model_type,
                    bos_token='[unused0]',
                    eos_token='[unused1]',
                    eoq_token='[unused2]'):
    """
    Get a tokenizer; the default bos/eos/eoq tokens reuse BERT's unused
    vocabulary slots.
    """
    from transformers.models.bert import BertTokenizerFast
    try:
        # Prefer a locally cached checkpoint.
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
            local_files_only=True,
        )
    except OSError:
        # No local copy; fall back to downloading the checkpoint.
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
        )

    tokenizer.bos_token = bos_token
    tokenizer.eos_token = eos_token
    tokenizer.eoq_token = eoq_token
    tokenizer.bos_token_id = tokenizer.vocab[bos_token]
    tokenizer.eos_token_id = tokenizer.vocab[eos_token]
    tokenizer.eoq_token_id = tokenizer.vocab[eoq_token]

    # Mirror the ids in the module-level globals used elsewhere.
    global BOS_TOKEN_ID, EOS_TOKEN_ID, EOQ_TOKEN_ID, PAD_TOKEN_ID
    BOS_TOKEN_ID = tokenizer.bos_token_id
    EOS_TOKEN_ID = tokenizer.eos_token_id
    EOQ_TOKEN_ID = tokenizer.eoq_token_id
    PAD_TOKEN_ID = tokenizer.pad_token_id

    return tokenizer


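# Usage sketch for setup_tokenizer, assuming 'bert-base-uncased' as the
# checkpoint (any BERT checkpoint with [unused*] vocab slots should work):
#
#   tokenizer = setup_tokenizer('bert-base-uncased')
#   tokenizer.eoq_token_id                   # vocab id of '[unused2]'
#   BOS_TOKEN_ID == tokenizer.bos_token_id   # module globals now populated

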
def load_synth_data(data_config):
    """
    Load the synthetic data for contrastive learning.
    """
    if data_config.is_debug:
        synth_dir = 'cache_debug/synthetic/'
    else:
        synth_dir = os.path.join(data_config.cache_dir, 'synthetic')

    logger.info("Loading synthetic data from '%s'", synth_dir)
    synth_prim_weight = data_config.hetero_synth_prim_weight

    # Array shapes are stored alongside the memmaps, since np.memmap
    # itself records no shape metadata.
    with open(os.path.join(synth_dir, 'shapes.json')) as f:
        shapes = json.load(f)

    synth_feat_path = os.path.join(
        synth_dir, 'feature_{}.memmap'.format(synth_prim_weight))
    synth_tok_path = os.path.join(
        synth_dir, 'token_{}.memmap'.format(synth_prim_weight))
    synth_feats = np.memmap(filename=synth_feat_path,
                            shape=tuple(shapes['feature']),
                            mode='r',
                            dtype=np.float32)
    synth_toks = np.memmap(filename=synth_tok_path,
                           shape=tuple(shapes['token']),
                           mode='r',
                           dtype=np.int64)

    # Keep only the first num_contrast examples, keyed by their index.
    num_contrast = data_config.num_contrast
    synth_feats = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_feats)[:num_contrast])
    }
    synth_toks = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_toks)[:num_contrast])
    }

    return synth_feats, synth_toks
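

# Usage sketch for load_synth_data; SimpleNamespace stands in for the real
# data config object, and the field values are illustrative:
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(is_debug=False, cache_dir='cache',
#                         hetero_synth_prim_weight=0.5, num_contrast=1000)
#   feats, toks = load_synth_data(cfg)
#   feats[0].shape  # feature tensor of the first synthetic example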