FS-TFP/federatedscope/nlp/hetero_tasks/dataset/utils.py


import os
import json
import logging

import numpy as np

try:
    import torch
    from torch.utils.data.dataset import Dataset
except ImportError:
    # Allow importing this module without a torch installation; the
    # torch-dependent helpers below will then be unusable.
    torch = None
    Dataset = None

# Cap on the number of examples used when running in debug mode.
NUM_DEBUG = 20

# Special-token ids; overwritten by setup_tokenizer() once a tokenizer is built.
BOS_TOKEN_ID = -1
EOS_TOKEN_ID = -1
EOQ_TOKEN_ID = -1
PAD_TOKEN_ID = -1

logger = logging.getLogger(__name__)


def split_sent(examples, eoq='[unused2]', tokenize=True):
    """Insert the end-of-sentence/question token ``eoq`` between sentences.

    If ``tokenize`` is True, sentence boundaries are found with NLTK's punkt
    tokenizer; otherwise existing '[SEP]' markers are replaced by ``eoq``.
    """
    import nltk
    from nltk.tokenize import sent_tokenize

    # Fetch the punkt sentence tokenizer on first use.
    try:
        nltk.data.find('tokenizers/punkt')
    except LookupError:
        nltk.download('punkt')

    new_examples = []
    for e in examples:
        if tokenize:
            # Split into sentences and rejoin them with the eoq token.
            e = f' {eoq} '.join(sent_tokenize(e))
        else:
            # The text is already segmented; just swap the separator token.
            e = e.replace('[SEP]', eoq)
        new_examples.append(e)
    return new_examples
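
# A minimal usage sketch (illustrative input; requires NLTK's punkt data):
#   >>> split_sent(['Hello there. How are you?'])
#   ['Hello there. [unused2] How are you?']
#   >>> split_sent(['Hello there. [SEP] How are you?'], tokenize=False)
#   ['Hello there. [unused2] How are you?']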


class DatasetDict(Dataset):
    """A map-style dataset backed by a dict of equally sized tensors."""
    def __init__(self, inputs):
        super().__init__()
        # All tensors must share the same first (sample) dimension.
        assert all(
            list(inputs.values())[0].size(0) == v.size(0)
            for v in inputs.values()), "Size mismatch between tensors"
        self.inputs = inputs

    def __getitem__(self, index):
        # Return one sample as a dict with the same keys as `inputs`.
        return {k: v[index] for k, v in self.inputs.items()}

    def __len__(self):
        return list(self.inputs.values())[0].size(0)
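
# A minimal usage sketch (assumes torch is available; the tensors below are
# placeholders, not real tokenized data):
#   >>> inputs = {'token_ids': torch.zeros(4, 16, dtype=torch.long),
#   ...           'attention_mask': torch.ones(4, 16, dtype=torch.long)}
#   >>> data = DatasetDict(inputs)
#   >>> len(data)
#   4
#   >>> sorted(data[0].keys())
#   ['attention_mask', 'token_ids']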


def setup_tokenizer(model_type,
                    bos_token='[unused0]',
                    eos_token='[unused1]',
                    eoq_token='[unused2]'):
    """
    Get a BERT tokenizer; the bos/eos/eoq tokens default to BERT's reserved
    ``[unusedX]`` vocabulary entries.
    """
    from transformers.models.bert import BertTokenizerFast

    # Prefer a locally cached checkpoint; fall back to downloading it.
    try:
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
            local_files_only=True,
        )
    except Exception:
        tokenizer = BertTokenizerFast.from_pretrained(
            model_type,
            additional_special_tokens=[bos_token, eos_token, eoq_token],
            skip_special_tokens=True,
        )

    # Attach the special tokens and their vocabulary ids to the tokenizer.
    tokenizer.bos_token = bos_token
    tokenizer.eos_token = eos_token
    tokenizer.eoq_token = eoq_token
    tokenizer.bos_token_id = tokenizer.vocab[bos_token]
    tokenizer.eos_token_id = tokenizer.vocab[eos_token]
    tokenizer.eoq_token_id = tokenizer.vocab[eoq_token]

    # Export the ids at module level so other dataset utilities can use them.
    global BOS_TOKEN_ID, EOS_TOKEN_ID, EOQ_TOKEN_ID, PAD_TOKEN_ID
    BOS_TOKEN_ID = tokenizer.bos_token_id
    EOS_TOKEN_ID = tokenizer.eos_token_id
    EOQ_TOKEN_ID = tokenizer.eoq_token_id
    PAD_TOKEN_ID = tokenizer.pad_token_id

    return tokenizer
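
# A minimal usage sketch (assumes a locally cached or downloadable
# 'bert-base-uncased' checkpoint; the ids come from that checkpoint's vocab):
#   >>> tokenizer = setup_tokenizer('bert-base-uncased')
#   >>> tokenizer.eoq_token
#   '[unused2]'
#   >>> tokenizer.eoq_token_id == tokenizer.vocab['[unused2]']
#   True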


def load_synth_data(data_config):
    """
    Load the synthetic (feature, token) pairs used for contrastive learning.
    """
    if data_config.is_debug:
        synth_dir = 'cache_debug/synthetic/'
    else:
        synth_dir = os.path.join(data_config.cache_dir, 'synthetic')
    logger.info('Loading synthetic data from \'{}\''.format(synth_dir))

    # The memmap array shapes are stored alongside the data files.
    synth_prim_weight = data_config.hetero_synth_prim_weight
    with open(os.path.join(synth_dir, 'shapes.json')) as f:
        shapes = json.load(f)
    synth_feat_path = os.path.join(
        synth_dir, 'feature_{}.memmap'.format(synth_prim_weight))
    synth_tok_path = os.path.join(synth_dir,
                                  'token_{}.memmap'.format(synth_prim_weight))
    synth_feats = np.memmap(filename=synth_feat_path,
                            shape=tuple(shapes['feature']),
                            mode='r',
                            dtype=np.float32)
    synth_toks = np.memmap(filename=synth_tok_path,
                           shape=tuple(shapes['token']),
                           mode='r',
                           dtype=np.int64)

    # Keep only the first `num_contrast` samples and index them by position.
    num_contrast = data_config.num_contrast
    synth_feats = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_feats)[:num_contrast])
    }
    synth_toks = {
        k: v
        for k, v in enumerate(torch.from_numpy(synth_toks)[:num_contrast])
    }
    return synth_feats, synth_toks
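
# A minimal usage sketch (hypothetical config values; assumes shapes.json and
# the corresponding .memmap files already exist under <cache_dir>/synthetic):
#   >>> from types import SimpleNamespace
#   >>> cfg = SimpleNamespace(is_debug=False, cache_dir='cache',
#   ...                       hetero_synth_prim_weight=0.5, num_contrast=4096)
#   >>> synth_feats, synth_toks = load_synth_data(cfg)
#   >>> synth_feats[0].shape   # feature vector of the first synthetic sample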