# FS-TFP/federatedscope/nlp/hetero_tasks/dataset/squad.py


import os
import os.path as osp
import torch
import logging
from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)


class SquadExample(object):
    """A single SQuAD example with word-level answer positions."""
    def __init__(self, qa_id, question, context, train_answer, val_answer,
                 start_pos, end_pos, context_tokens, is_impossible):
        self.qa_id = qa_id
        self.question = question
        self.context = context
        self.train_answer = train_answer
        self.val_answer = val_answer
        self.start_position = start_pos
        self.end_position = end_pos
        self.context_tokens = context_tokens
        self.is_impossible = is_impossible


class SquadEncodedInput(object):
    """Encoded ids for one question/context span, padded to max_seq_len."""
    def __init__(self, token_ids, token_type_ids, attention_mask,
                 overflow_token_ids):
        self.token_ids = token_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.overflow_token_ids = overflow_token_ids


class SquadResult(object):
    """Predicted start/end logits for one encoded span, keyed by unique_id."""
    def __init__(self, unique_id, start_logits, end_logits):
        self.unique_id = unique_id
        self.start_logits = start_logits
        self.end_logits = end_logits


def refine_subtoken_position(context_subtokens, subtoken_start_pos,
                             subtoken_end_pos, tokenizer, annotated_answer):
    """Shrink a subtoken span to the tightest span that exactly matches the
    tokenized annotated answer; fall back to the original span otherwise."""
    subtoken_answer = ' '.join(tokenizer.tokenize(annotated_answer))
    for new_st in range(subtoken_start_pos, subtoken_end_pos + 1):
        for new_ed in range(subtoken_end_pos, subtoken_start_pos - 1, -1):
            text_span = ' '.join(context_subtokens[new_st:(new_ed + 1)])
            if text_span == subtoken_answer:
                return new_st, new_ed
    return subtoken_start_pos, subtoken_end_pos
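

# A hypothetical illustration of refine_subtoken_position (the subtokens and
# the tokenizer output below are assumed, not taken from the repository):
#
#   context_subtokens = ['(', '18', '##95', ')'], annotated_answer = '1895'
#   tokenizer.tokenize('1895') -> ['18', '##95']
#   refine_subtoken_position(context_subtokens, 0, 3, tokenizer, '1895')
#   # -> (1, 2): the surrounding parentheses are dropped from the span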


def get_char_to_word_positions(context, answer, start_char_pos, is_impossible):
    """Whitespace-tokenize the context and map the answer's character offset
    to word-level start/end positions."""
    context_tokens = []
    char_to_word_offset = []
    is_prev_whitespace = True
    for c in context:
        is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n'
                         or ord(c) == 0x202F)
        if is_whitespace:
            is_prev_whitespace = True
        else:
            if is_prev_whitespace:
                context_tokens.append(c)
            else:
                context_tokens[-1] += c
            is_prev_whitespace = False
        char_to_word_offset.append(len(context_tokens) - 1)

    start_pos, end_pos = 0, 0
    if start_char_pos is not None and not is_impossible:
        start_pos = char_to_word_offset[start_char_pos]
        end_pos = char_to_word_offset[start_char_pos + len(answer) - 1]
    return start_pos, end_pos, context_tokens
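

# A minimal worked example for get_char_to_word_positions, derived from the
# function logic above (the text itself is made up):
#
#   context = 'Kilimanjaro is in Tanzania'
#   start_char_pos = context.find('Tanzania')   # 18
#   get_char_to_word_positions(context, 'Tanzania', start_char_pos, False)
#   # -> (3, 3, ['Kilimanjaro', 'is', 'in', 'Tanzania'])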


def check_max_context_token(all_spans, cur_span_idx, pos):
    """Return True if `cur_span_idx` is the sliding-window span in which the
    subtoken at `pos` has the most surrounding context, scored as
    min(left context, right context) + 0.01 * span length."""
    best_score, best_span_idx = None, None
    for span_idx, span in enumerate(all_spans):
        end = span.context_start_position + span.context_len - 1
        if pos < span.context_start_position or pos > end:
            continue
        num_left_context = pos - span.context_start_position
        num_right_context = end - pos
        score = \
            min(num_left_context, num_right_context) + 0.01 * span.context_len
        if best_score is None or score > best_score:
            best_score = score
            best_span_idx = span_idx
    return cur_span_idx == best_span_idx


def encode(tokenizer, text_a, text_b, max_seq_len, max_query_len,
           added_trunc_size):
    """Encode a question (text_a) / context (text_b) pair into fixed-length
    ids, keeping the truncated context tail in `overflow_token_ids` so the
    caller can build the next sliding-window span."""
    def _get_token_ids(text):
        if isinstance(text, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], str):
            return tokenizer.convert_tokens_to_ids(text)
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], int):
            return text
        else:
            raise ValueError('Input is not valid, should be a string, '
                             'a list/tuple of strings or a list/tuple of '
                             'integers.')

    token_ids_a = _get_token_ids(text_a)
    token_ids_b = _get_token_ids(text_b)

    # Truncate
    overflow_token_ids = None
    len_a = len(token_ids_a) + 2
    total_len = len(token_ids_a) + len(token_ids_b) + 3
    if len_a > max_query_len:
        num_remove = len_a - max_query_len
        token_ids_a = token_ids_a[:-num_remove]
    if total_len > max_seq_len:
        num_remove = total_len - max_seq_len
        trunc_size = min(len(token_ids_b), added_trunc_size + num_remove)
        overflow_token_ids = token_ids_b[-trunc_size:]
        token_ids_b = token_ids_b[:-num_remove]

    # Combine and pad
    token_ids = \
        [tokenizer.cls_token_id] + token_ids_a + [tokenizer.sep_token_id]
    token_type_ids = [0] * len(token_ids)
    token_ids += token_ids_b + [tokenizer.sep_token_id]
    token_type_ids += [1] * (len(token_ids_b) + 1)
    attention_mask = [1] * len(token_ids)
    if len(token_ids) < max_seq_len:
        dif = max_seq_len - len(token_ids)
        token_ids += [tokenizer.pad_token_id] * dif
        token_type_ids += [0] * dif
        attention_mask += [0] * dif

    return SquadEncodedInput(token_ids, token_type_ids, attention_mask,
                             overflow_token_ids)
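

# A schematic of the sequence layout produced by encode(), inferred from the
# code above (widths are illustrative; the PAD run depends on max_seq_len):
#
#   token_ids:      [CLS] q_1 .. q_m [SEP] c_1 .. c_n [SEP] [PAD] .. [PAD]
#   token_type_ids:   0    0  ..  0    0    1  ..  1    1     0   ..   0
#   attention_mask:   1    1  ..  1    1    1  ..  1    1     0   ..   0
#
# When the context does not fit, the dropped context ids (plus an overlap of
# `added_trunc_size`) are returned in `overflow_token_ids` for the next span.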


def get_squad_examples(data, split, is_debug=False):
    """Convert raw per-client SQuAD records into SquadExample objects."""
    if is_debug:
        data = data[:NUM_DEBUG]
    examples = []
    for para in data:
        context = para['context']
        qa = para['qa']
        qa_id = qa['id']
        question = qa['question']
        start_char_pos = None
        train_answer = None
        val_answer = []

        is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False
        if not is_impossible:
            if split == 'train':
                train_answer = qa['answers'][0]['text']
                start_char_pos = qa['answers'][0]['answer_start']
            else:
                val_answer = qa['answers']

        start_pos, end_pos, context_tokens = get_char_to_word_positions(
            context, train_answer, start_char_pos, is_impossible)
        examples.append(
            SquadExample(qa_id, question, context, train_answer, val_answer,
                         start_pos, end_pos, context_tokens, is_impossible))
    return examples
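

# The `data` argument above is a flat list of context/qa records. A minimal
# hypothetical record, with field names following the accesses in
# get_squad_examples (the text is made up):
#
#   data = [{
#       'context': 'Kilimanjaro is in Tanzania.',
#       'qa': {
#           'id': '0001',
#           'question': 'Where is Kilimanjaro?',
#           'is_impossible': False,
#           'answers': [{'text': 'Tanzania', 'answer_start': 18}],
#       },
#   }]
#   examples = get_squad_examples(data, split='train')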


def process_squad_dataset(data,
                          split,
                          tokenizer,
                          max_seq_len,
                          max_query_len,
                          trunc_stride,
                          cache_dir='',
                          client_id=None,
                          pretrain=False,
                          is_debug=False,
                          **kwargs):
    """Tokenize SQuAD examples into sliding-window spans for extractive QA,
    with optional per-client caching of the processed features."""
    if pretrain:
        return process_squad_dataset_for_pretrain(data, split, tokenizer,
                                                  max_seq_len, cache_dir,
                                                  client_id, is_debug)

    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_squad_examples(data, split, is_debug)

        unique_id = 1000000000
        encoded_inputs = []
        for example_idx, example in enumerate(examples):
            if split == 'train' and not example.is_impossible:
                start_pos = example.start_position
                end_pos = example.end_position

                # Skip examples whose annotated answer cannot be recovered
                # from the whitespace-tokenized context.
                actual_answer = ' '.join(
                    example.context_tokens[start_pos:(end_pos + 1)])
                cleaned_answer = ' '.join(example.train_answer.strip().split())
                if actual_answer.find(cleaned_answer) == -1:
                    logger.info('Could not find answer: {} vs. {}'.format(
                        actual_answer, cleaned_answer))
                    continue

            # Map word-level tokens to tokenizer subtokens.
            tok_to_subtok_idx = []
            subtok_to_tok_idx = []
            context_subtokens = []
            for i, token in enumerate(example.context_tokens):
                tok_to_subtok_idx.append(len(context_subtokens))
                subtokens = tokenizer.tokenize(token)
                for subtoken in subtokens:
                    subtok_to_tok_idx.append(i)
                    context_subtokens.append(subtoken)

            if split == 'train' and not example.is_impossible:
                subtoken_start_pos = tok_to_subtok_idx[example.start_position]
                if example.end_position < len(example.context_tokens) - 1:
                    subtoken_end_pos = \
                        tok_to_subtok_idx[example.end_position + 1] - 1
                else:
                    subtoken_end_pos = len(context_subtokens) - 1
                subtoken_start_pos, subtoken_end_pos = \
                    refine_subtoken_position(context_subtokens,
                                             subtoken_start_pos,
                                             subtoken_end_pos, tokenizer,
                                             example.train_answer)

            # Split the context into overlapping spans of at most max_seq_len.
            truncated_context = context_subtokens
            len_question = min(len(tokenizer.tokenize(example.question)),
                               max_query_len - 2)
            added_trunc_size = max_seq_len - trunc_stride - len_question - 3
            spans = []
            while len(spans) * trunc_stride < len(context_subtokens):
                text_a = example.question
                text_b = truncated_context
                encoded_input = encode(tokenizer, text_a, text_b, max_seq_len,
                                       max_query_len, added_trunc_size)
                context_start_pos = len(spans) * trunc_stride
                context_len = min(
                    len(context_subtokens) - context_start_pos,
                    max_seq_len - len_question - 3)
                context_end_pos = context_start_pos + context_len - 1

                if tokenizer.pad_token_id in encoded_input.token_ids:
                    pad_start = encoded_input.token_ids.index(
                        tokenizer.pad_token_id)
                    non_padded_ids = encoded_input.token_ids[:pad_start]
                else:
                    non_padded_ids = encoded_input.token_ids
                tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

                context_subtok_to_tok_idx = {}
                for i in range(context_len):
                    context_idx = len_question + i + 2
                    context_subtok_to_tok_idx[context_idx] = \
                        subtok_to_tok_idx[context_start_pos + i]

                start_pos, end_pos = 0, 0
                span_is_impossible = example.is_impossible
                if split == 'train' and not span_is_impossible:
                    # For training, if our document chunk does not contain
                    # an annotation we throw it out, since there is nothing
                    # to predict.
                    if subtoken_start_pos >= context_start_pos and \
                            subtoken_end_pos <= context_end_pos:
                        context_offset = len_question + 2
                        start_pos = \
                            subtoken_start_pos - context_start_pos + \
                            context_offset
                        end_pos = \
                            subtoken_end_pos - context_start_pos + \
                            context_offset
                    else:
                        start_pos = 0
                        end_pos = 0
                        span_is_impossible = True

                encoded_input.start_position = start_pos
                encoded_input.end_position = end_pos
                encoded_input.is_impossible = span_is_impossible

                # For computing metrics
                encoded_input.example_index = example_idx
                encoded_input.context_start_position = context_start_pos
                encoded_input.context_len = context_len
                encoded_input.tokens = tokens
                encoded_input.context_subtok_to_tok_idx = \
                    context_subtok_to_tok_idx
                encoded_input.is_max_context_token = {}
                encoded_input.unique_id = unique_id

                spans.append(encoded_input)
                unique_id += 1
                if encoded_input.overflow_token_ids is None:
                    break
                truncated_context = encoded_input.overflow_token_ids

            # Mark, for every context position, the span in which it has the
            # most surrounding context.
            for span_idx in range(len(spans)):
                for context_idx in range(spans[span_idx].context_len):
                    is_max_context_token = check_max_context_token(
                        spans, span_idx, span_idx * trunc_stride + context_idx)
                    idx = len_question + context_idx + 2
                    spans[span_idx].is_max_context_token[idx] = \
                        is_max_context_token
            encoded_inputs.extend(spans)

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs])
    token_type_ids = torch.LongTensor(
        [inp.token_type_ids for inp in encoded_inputs])
    attention_mask = torch.LongTensor(
        [inp.attention_mask for inp in encoded_inputs])
    start_positions = torch.LongTensor(
        [inp.start_position for inp in encoded_inputs])
    end_positions = torch.LongTensor(
        [inp.end_position for inp in encoded_inputs])
    example_indices = torch.arange(token_ids.size(0), dtype=torch.long)

    dataset = DatasetDict({
        'token_ids': token_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'start_positions': start_positions,
        'end_positions': end_positions,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
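

# A minimal usage sketch for the fine-tuning path, assuming a standard
# HuggingFace BERT tokenizer and the toy `data` format shown after
# get_squad_examples (the argument values are illustrative, not the
# repository's defaults):
#
#   from transformers import AutoTokenizer
#   tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
#   dataset, encoded_inputs, examples = process_squad_dataset(
#       data, split='train', tokenizer=tokenizer, max_seq_len=384,
#       max_query_len=64, trunc_stride=128, cache_dir='', client_id=1)
#   # dataset['token_ids'].shape == (num_spans, 384)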


def process_squad_dataset_for_pretrain(data,
                                       split,
                                       tokenizer,
                                       max_seq_len,
                                       cache_dir='',
                                       client_id=None,
                                       is_debug=False):
    """Encode raw contexts, split into sentences, for the pretraining stage,
    with optional per-client caching."""
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_squad_examples(data, split, is_debug)
        texts = split_sent([e.context for e in examples],
                           eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')

        # Overwrite the first and last non-padding tokens with BOS/EOS.
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
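

# Note on the pretraining path: it relies on tokenizer attributes that are
# specific to the hetero-task setup (eoq_token, bos_token_id, eos_token_id);
# a plain HuggingFace tokenizer may not define eoq_token. A hedged sketch,
# assuming the project's tokenizer setup has added these special tokens:
#
#   dataset, encoded_inputs, examples = process_squad_dataset(
#       data, split='train', tokenizer=tokenizer, max_seq_len=128,
#       max_query_len=64, trunc_stride=32, pretrain=True,
#       cache_dir='', client_id=1)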