import os
import json
import logging

import numpy as np

try:
    import torch
    from torch.utils.data.dataset import Dataset
except ImportError:
    # Keep the module importable without torch; DatasetDict is unusable in that case.
    torch = None
    Dataset = object

NUM_DEBUG = 20

# Sentinel values; the real ids are filled in by setup_tokenizer().
BOS_TOKEN_ID = -1
EOS_TOKEN_ID = -1
EOQ_TOKEN_ID = -1
PAD_TOKEN_ID = -1

logger = logging.getLogger(__name__)


def split_sent(examples, eoq='[unused2]', tokenize=True):
    """Join the sentences of each example with the end-of-question token."""
    import nltk
    from nltk.tokenize import sent_tokenize

    # Make sure the punkt sentence tokenizer is available locally.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    new_examples = []
    for e in examples:
        if tokenize:
            e = f' {eoq} '.join(sent_tokenize(e))
        else:
            # Examples are already split; only swap the separator token.
            e = e.replace('[SEP]', eoq)
        new_examples.append(e)
    return new_examples


class DatasetDict(Dataset):
    """Map-style dataset over a dict of tensors that share the same first dimension."""

    def __init__(self, inputs):
        super().__init__()
        assert all(
            list(inputs.values())[0].size(0) == v.size(0)
            for v in inputs.values()), "Size mismatch between tensors"
        self.inputs = inputs

    def __getitem__(self, index):
        return {k: v[index] for k, v in self.inputs.items()}

    def __len__(self):
        return list(self.inputs.values())[0].size(0)


def setup_tokenizer(model_type,
                    bos_token='[unused0]',
                    eos_token='[unused1]',
                    eoq_token='[unused2]'):
    """
    Get a tokenizer; the default bos/eos/eoq tokens reuse BERT's unused vocabulary entries.
    """
    from transformers.models.bert import BertTokenizerFast

    try:
        # Prefer a locally cached tokenizer to avoid network access.
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
            local_files_only=True,
        )
    except Exception:
        # Fall back to downloading if no local copy is available.
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
        )

    tokenizer.bos_token = bos_token
    tokenizer.eos_token = eos_token
    tokenizer.eoq_token = eoq_token
    tokenizer.bos_token_id = tokenizer.vocab[bos_token]
    tokenizer.eos_token_id = tokenizer.vocab[eos_token]
    tokenizer.eoq_token_id = tokenizer.vocab[eoq_token]

    # Expose the special token ids at module level for downstream code.
    global BOS_TOKEN_ID, EOS_TOKEN_ID, EOQ_TOKEN_ID, PAD_TOKEN_ID
    BOS_TOKEN_ID = tokenizer.bos_token_id
    EOS_TOKEN_ID = tokenizer.eos_token_id
    EOQ_TOKEN_ID = tokenizer.eoq_token_id
    PAD_TOKEN_ID = tokenizer.pad_token_id
    return tokenizer


def load_synth_data(data_config):
    """
    Load the memory-mapped synthetic features and tokens used for contrastive learning.
    """
    if data_config.is_debug:
        synth_dir = 'cache_debug/synthetic/'
    else:
        synth_dir = os.path.join(data_config.cache_dir, 'synthetic')
    logger.info('Loading synthetic data from \'{}\''.format(synth_dir))

    synth_prim_weight = data_config.hetero_synth_prim_weight

    # Array shapes are stored alongside the memmap files.
    with open(os.path.join(synth_dir, 'shapes.json')) as f:
        shapes = json.load(f)

    synth_feat_path = os.path.join(
        synth_dir, 'feature_{}.memmap'.format(synth_prim_weight))
    synth_tok_path = os.path.join(
        synth_dir, 'token_{}.memmap'.format(synth_prim_weight))

    synth_feats = np.memmap(filename=synth_feat_path,
                            shape=tuple(shapes['feature']),
                            mode='r',
                            dtype=np.float32)
    synth_toks = np.memmap(filename=synth_tok_path,
                           shape=tuple(shapes['token']),
                           mode='r',
                           dtype=np.int64)

    # Keep only the first num_contrast examples, indexed by position.
    num_contrast = data_config.num_contrast
    synth_feats = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_feats)[:num_contrast])
    }
    synth_toks = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_toks)[:num_contrast])
    }
    return synth_feats, synth_toks
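

# Illustrative usage sketch, not part of the original module: the model name
# 'bert-base-uncased' and the toy documents below are assumptions made purely
# for demonstration; running it requires transformers, nltk, and torch.
if __name__ == '__main__':
    docs = [
        'Where is the station? It is two blocks north.',
        'What time does the train leave? It leaves at noon.',
    ]
    tokenizer = setup_tokenizer('bert-base-uncased')
    # Insert the end-of-question token between sentences of each document.
    docs = split_sent(docs, eoq=tokenizer.eoq_token)
    encoded = tokenizer(docs, padding=True, truncation=True, return_tensors='pt')
    # Wrap the aligned tensors so each index yields one example as a dict.
    dataset = DatasetDict({k: v for k, v in encoded.items()})
    print(len(dataset), sorted(dataset[0].keys()))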