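"""Preprocessing of the NewsQA dataset for the heterogeneous NLP tasks in
FederatedScope.

Raw records (a context passage plus one question and its detected answer
character spans) are converted into SQuAD-style features: the context is
split into overlapping sub-token windows, encoded together with the
question, padded to ``max_seq_len`` and packed into a ``DatasetDict``.
A sentence-level variant without answer spans is provided for pretraining.
"""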
import os
import os.path as osp
import torch
import logging

from federatedscope.nlp.hetero_tasks.dataset.utils import split_sent, \
    DatasetDict, NUM_DEBUG

logger = logging.getLogger(__name__)


class NewsQAExample(object):
    """A single NewsQA question with its context and annotated answer(s)."""
    def __init__(self, qa_id, question, context, train_answer, val_answer,
                 start_pos, end_pos, context_tokens, is_impossible):
        self.qa_id = qa_id
        self.question = question
        self.context = context
        self.train_answer = train_answer
        self.val_answer = val_answer
        self.start_position = start_pos
        self.end_position = end_pos
        self.context_tokens = context_tokens
        self.is_impossible = is_impossible


class NewsQAEncodedInput(object):
    """Tokenized question-context pair produced by ``encode``."""
    def __init__(self, token_ids, token_type_ids, attention_mask,
                 overflow_token_ids):
        self.token_ids = token_ids
        self.token_type_ids = token_type_ids
        self.attention_mask = attention_mask
        self.overflow_token_ids = overflow_token_ids


class NewsQAResult(object):
    """Model output (start/end logits) for one encoded span."""
    def __init__(self, unique_id, start_logits, end_logits):
        self.unique_id = unique_id
        self.start_logits = start_logits
        self.end_logits = end_logits


def refine_subtoken_position(context_subtokens, subtoken_start_pos,
                             subtoken_end_pos, tokenizer, annotated_answer):
    """Tighten a sub-token span so that it matches the tokenized answer.

    The initial span is derived from whitespace tokens and may include extra
    sub-tokens; the smallest enclosed span whose text equals the tokenized
    annotated answer is returned, or the original span if no match exists.
    """
    subtoken_answer = ' '.join(tokenizer.tokenize(annotated_answer))
    for new_st in range(subtoken_start_pos, subtoken_end_pos + 1):
        for new_ed in range(subtoken_end_pos, subtoken_start_pos - 1, -1):
            text_span = ' '.join(context_subtokens[new_st:(new_ed + 1)])
            if text_span == subtoken_answer:
                return new_st, new_ed
    return subtoken_start_pos, subtoken_end_pos


def get_char_to_word_positions(context, answer, start_char_pos, is_impossible):
    context_tokens = []
    char_to_word_offset = []
    is_prev_whitespace = True
    for c in context:
        is_whitespace = (c == ' ' or c == '\t' or c == '\r' or c == '\n'
                         or ord(c) == 0x202F)
        if is_whitespace:
            is_prev_whitespace = True
        else:
            if is_prev_whitespace:
                context_tokens.append(c)
            else:
                context_tokens[-1] += c
            is_prev_whitespace = False
        char_to_word_offset.append(len(context_tokens) - 1)

    start_pos, end_pos = 0, 0
    if start_char_pos is not None and not is_impossible:
        start_pos = char_to_word_offset[start_char_pos]
        end_pos = char_to_word_offset[start_char_pos + len(answer) - 1]
    return start_pos, end_pos, context_tokens


def check_max_context_token(all_spans, cur_span_idx, pos):
    """Return True if span ``cur_span_idx`` gives sub-token ``pos`` its
    maximum context among all overlapping spans.

    A sub-token may appear in several overlapping windows; the span in which
    it has the most surrounding context (scored by the minimum of left and
    right context, plus a small bonus for longer spans) is the one used when
    collecting predictions for that position.
    """
    best_score, best_span_idx = None, None
    for span_idx, span in enumerate(all_spans):
        end = span.context_start_position + span.context_len - 1
        if pos < span.context_start_position or pos > end:
            continue
        num_left_context = pos - span.context_start_position
        num_right_context = end - pos
        score = \
            min(num_left_context, num_right_context) + 0.01 * span.context_len
        if best_score is None or score > best_score:
            best_score = score
            best_span_idx = span_idx
    return cur_span_idx == best_span_idx


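# Worked example for ``check_max_context_token`` (comment only): if span A
# covers context positions 0-99 and span B covers 50-149 (both of length
# 100), position 60 scores min(60, 39) + 0.01 * 100 = 40.0 in A but only
# min(10, 89) + 0.01 * 100 = 11.0 in B, so A is its max-context span and
# B's copy of that sub-token is ignored when collecting predictions.

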
def encode(tokenizer, text_a, text_b, max_seq_len, max_query_len,
           added_trunc_size):
    """Encode a question (``text_a``) / context (``text_b``) pair as
    ``[CLS] text_a [SEP] text_b [SEP]`` padded up to ``max_seq_len``.

    The question is truncated so that it fits ``max_query_len`` (counting the
    two special tokens) and the combined sequence is truncated to
    ``max_seq_len``; the context tail that is cut off, plus up to
    ``added_trunc_size`` overlapping tokens, is returned as
    ``overflow_token_ids`` so the caller can build the next sliding window.
    """
    def _get_token_ids(text):
        if isinstance(text, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], str):
            return tokenizer.convert_tokens_to_ids(text)
        elif isinstance(text, (list, tuple)) and len(text) > 0 and \
                isinstance(text[0], int):
            return text
        else:
            raise ValueError('Input is not valid, should be a string, '
                             'a list/tuple of strings or a list/tuple of '
                             'integers.')

    token_ids_a = _get_token_ids(text_a)
    token_ids_b = _get_token_ids(text_b)

    # Truncate
    overflow_token_ids = None
    len_a = len(token_ids_a) + 2
    total_len = len(token_ids_a) + len(token_ids_b) + 3
    if len_a > max_query_len:
        num_remove = len_a - max_query_len
        token_ids_a = token_ids_a[:-num_remove]
    if total_len > max_seq_len:
        num_remove = total_len - max_seq_len
        trunc_size = min(len(token_ids_b), added_trunc_size + num_remove)
        overflow_token_ids = token_ids_b[-trunc_size:]
        token_ids_b = token_ids_b[:-num_remove]

    # Combine and pad
    token_ids = \
        [tokenizer.cls_token_id] + token_ids_a + [tokenizer.sep_token_id]
    token_type_ids = [0] * len(token_ids)
    token_ids += token_ids_b + [tokenizer.sep_token_id]
    token_type_ids += [1] * (len(token_ids_b) + 1)
    attention_mask = [1] * len(token_ids)
    if len(token_ids) < max_seq_len:
        dif = max_seq_len - len(token_ids)
        token_ids += [tokenizer.pad_token_id] * dif
        token_type_ids += [0] * dif
        attention_mask += [0] * dif

    return NewsQAEncodedInput(token_ids, token_type_ids, attention_mask,
                              overflow_token_ids)


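# Illustration of the layout returned by ``encode`` (comment only, lengths
# not to scale):
#
#   token_ids      : [CLS] q1 q2 ... [SEP] c1 c2 ... [SEP] [PAD] [PAD] ...
#   token_type_ids :   0    0  0 ...   0    1  1 ...   1     0     0   ...
#   attention_mask :   1    1  1 ...   1    1  1 ...   1     0     0   ...
#
# ``overflow_token_ids`` holds the context tail (plus some overlap) that did
# not fit, so ``process_newsqa_dataset`` can build the next sliding window.

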
def get_newsqa_examples(data, split, is_debug=False):
    """Build ``NewsQAExample`` objects from raw NewsQA records.

    Each record carries a context and one QA item; for training only the
    earliest detected answer span is kept, while for validation/test all
    detected spans are kept as references.
    """
    if is_debug:
        data = data[:NUM_DEBUG]
    examples = []
    for para in data:
        context = para['context']
        qa = para['qa']
        qa_id = qa['qid']
        question = qa['question']
        start_char_pos = None
        train_answer = None
        val_answer = []

        is_impossible = qa['is_impossible'] if 'is_impossible' in qa else False
        if not is_impossible:
            answers = qa['detected_answers']
            spans = sorted(
                [span for spans in answers for span in spans['char_spans']])
            if split == 'train':
                train_answer = context[spans[0][0]:spans[0][1] + 1]
                start_char_pos = spans[0][0]
            else:
                val_answer = [{
                    'text': context[spans[i][0]:spans[i][1] + 1],
                    'answer_start': spans[i][0]
                } for i in range(len(spans))]

        start_pos, end_pos, context_tokens = get_char_to_word_positions(
            context, train_answer, start_char_pos, is_impossible)
        examples.append(
            NewsQAExample(qa_id, question, context, train_answer, val_answer,
                          start_pos, end_pos, context_tokens, is_impossible))
    return examples


def process_newsqa_dataset(data,
                           split,
                           tokenizer,
                           max_seq_len,
                           max_query_len,
                           trunc_stride,
                           cache_dir='',
                           client_id=None,
                           pretrain=False,
                           is_debug=False,
                           **kwargs):
    """Convert raw NewsQA records into padded QA features.

    Returns a tuple ``(dataset, encoded_inputs, examples)``: ``dataset`` is a
    ``DatasetDict`` of tensors fed to the model, while ``encoded_inputs`` and
    ``examples`` keep the bookkeeping needed to map predictions back to
    answer strings. Results are cached under ``cache_dir`` per client.
    """
    if pretrain:
        return process_newsqa_dataset_for_pretrain(data, split, tokenizer,
                                                   max_seq_len, cache_dir,
                                                   client_id, is_debug)

    save_dir = osp.join(cache_dir, 'train', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_newsqa_examples(data, split, is_debug)
        unique_id = 1000000000
        encoded_inputs = []
        for example_idx, example in enumerate(examples):
            if split == 'train' and not example.is_impossible:
                start_pos = example.start_position
                end_pos = example.end_position
                actual_answer = ' '.join(
                    example.context_tokens[start_pos:(end_pos + 1)])
                cleaned_answer = ' '.join(example.train_answer.strip().split())
                if actual_answer.find(cleaned_answer) == -1:
                    logger.info('Could not find answer: {} vs. {}'.format(
                        actual_answer, cleaned_answer))
                    continue

            tok_to_subtok_idx = []
            subtok_to_tok_idx = []
            context_subtokens = []
            for i, token in enumerate(example.context_tokens):
                tok_to_subtok_idx.append(len(context_subtokens))
                subtokens = tokenizer.tokenize(token)
                for subtoken in subtokens:
                    subtok_to_tok_idx.append(i)
                    context_subtokens.append(subtoken)

            if split == 'train' and not example.is_impossible:
                subtoken_start_pos = tok_to_subtok_idx[example.start_position]
                if example.end_position < len(example.context_tokens) - 1:
                    subtoken_end_pos = \
                        tok_to_subtok_idx[example.end_position + 1] - 1
                else:
                    subtoken_end_pos = len(context_subtokens) - 1
                subtoken_start_pos, subtoken_end_pos = \
                    refine_subtoken_position(context_subtokens,
                                             subtoken_start_pos,
                                             subtoken_end_pos, tokenizer,
                                             example.train_answer)

            truncated_context = context_subtokens
            len_question = min(len(tokenizer.tokenize(example.question)),
                               max_query_len - 2)
            added_trunc_size = max_seq_len - trunc_stride - len_question - 3
            spans = []
            while len(spans) * trunc_stride < len(context_subtokens):
                text_a = example.question
                text_b = truncated_context
                encoded_input = encode(tokenizer, text_a, text_b, max_seq_len,
                                       max_query_len, added_trunc_size)
                context_start_pos = len(spans) * trunc_stride
                context_len = min(len(context_subtokens) - context_start_pos,
                                  max_seq_len - len_question - 3)
                context_end_pos = context_start_pos + context_len - 1

                if tokenizer.pad_token_id in encoded_input.token_ids:
                    pad_start = \
                        encoded_input.token_ids.index(tokenizer.pad_token_id)
                    non_padded_ids = encoded_input.token_ids[:pad_start]
                else:
                    non_padded_ids = encoded_input.token_ids
                tokens = tokenizer.convert_ids_to_tokens(non_padded_ids)

                context_subtok_to_tok_idx = {}
                for i in range(context_len):
                    context_idx = len_question + i + 2
                    context_subtok_to_tok_idx[context_idx] = \
                        subtok_to_tok_idx[context_start_pos + i]

                start_pos, end_pos = 0, 0
                span_is_impossible = example.is_impossible
                if split == 'train' and not span_is_impossible:
                    # For training, if our document chunk does not contain
                    # an annotation we throw it out, since there is nothing
                    # to predict.
                    if subtoken_start_pos >= context_start_pos and \
                            subtoken_end_pos <= context_end_pos:
                        context_offset = len_question + 2
                        start_pos = subtoken_start_pos - context_start_pos + \
                            context_offset
                        end_pos = subtoken_end_pos - context_start_pos + \
                            context_offset
                    else:
                        start_pos = 0
                        end_pos = 0
                        span_is_impossible = True

                encoded_input.start_position = start_pos
                encoded_input.end_position = end_pos
                encoded_input.is_impossible = span_is_impossible

                # For computing metrics
                encoded_input.example_index = example_idx
                encoded_input.context_start_position = context_start_pos
                encoded_input.context_len = context_len
                encoded_input.tokens = tokens
                encoded_input.context_subtok_to_tok_idx = \
                    context_subtok_to_tok_idx
                encoded_input.is_max_context_token = {}
                encoded_input.unique_id = unique_id
                spans.append(encoded_input)
                unique_id += 1

                if encoded_input.overflow_token_ids is None:
                    break
                truncated_context = encoded_input.overflow_token_ids

            for span_idx in range(len(spans)):
                for context_idx in range(spans[span_idx].context_len):
                    is_max_context_token = check_max_context_token(
                        spans, span_idx, span_idx * trunc_stride + context_idx)
                    idx = len_question + context_idx + 2
                    spans[span_idx].is_max_context_token[idx] = \
                        is_max_context_token
            encoded_inputs.extend(spans)

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    token_ids = torch.LongTensor([inp.token_ids for inp in encoded_inputs])
    token_type_ids = torch.LongTensor(
        [inp.token_type_ids for inp in encoded_inputs])
    attention_mask = torch.LongTensor(
        [inp.attention_mask for inp in encoded_inputs])
    start_positions = torch.LongTensor(
        [inp.start_position for inp in encoded_inputs])
    end_positions = torch.LongTensor(
        [inp.end_position for inp in encoded_inputs])

    example_indices = torch.arange(token_ids.size(0), dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': token_ids,
        'token_type_ids': token_type_ids,
        'attention_mask': attention_mask,
        'start_positions': start_positions,
        'end_positions': end_positions,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples


def process_newsqa_dataset_for_pretrain(data,
                                        split,
                                        tokenizer,
                                        max_seq_len,
                                        cache_dir='',
                                        client_id=None,
                                        is_debug=False):
    """Tokenize sentence-split NewsQA contexts for pretraining.

    No answer spans are built here; each context is split into sentences via
    ``split_sent``, padded or truncated to ``max_seq_len``, and wrapped with
    the tokenizer's BOS/EOS tokens.
    """
    save_dir = osp.join(cache_dir, 'pretrain', str(client_id))
    cache_file = osp.join(save_dir, split + '.pt')
    if osp.exists(cache_file):
        logger.info('Loading cache file from \'{}\''.format(cache_file))
        cache_data = torch.load(cache_file)
        examples = cache_data['examples']
        encoded_inputs = cache_data['encoded_inputs']
    else:
        examples = get_newsqa_examples(data, split, is_debug)
        texts = split_sent([e.context for e in examples],
                           eoq=tokenizer.eoq_token)
        encoded_inputs = tokenizer(texts,
                                   padding='max_length',
                                   truncation=True,
                                   max_length=max_seq_len,
                                   return_tensors='pt')
        num_non_padding = (encoded_inputs.input_ids !=
                           tokenizer.pad_token_id).sum(dim=-1)
        for i, pad_idx in enumerate(num_non_padding):
            encoded_inputs.input_ids[i, 0] = tokenizer.bos_token_id
            encoded_inputs.input_ids[i, pad_idx - 1] = tokenizer.eos_token_id

        if cache_dir:
            logger.info('Saving cache file to \'{}\''.format(cache_file))
            os.makedirs(save_dir, exist_ok=True)
            torch.save({
                'examples': examples,
                'encoded_inputs': encoded_inputs
            }, cache_file)

    example_indices = torch.arange(encoded_inputs.input_ids.size(0),
                                   dtype=torch.long)
    dataset = DatasetDict({
        'token_ids': encoded_inputs.input_ids,
        'attention_mask': encoded_inputs.attention_mask,
        'example_indices': example_indices
    })
    return dataset, encoded_inputs, examples
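

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original pipeline).
# It assumes the ``transformers`` package and access to the
# ``bert-base-uncased`` checkpoint, and feeds a single hand-written record in
# the structure read by ``get_newsqa_examples`` ('context' plus a 'qa' dict
# with 'qid', 'question' and 'detected_answers' holding 'char_spans').
if __name__ == '__main__':
    from transformers import AutoTokenizer

    toy_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    toy_data = [{
        'context': 'NewsQA is a machine comprehension dataset.',
        'qa': {
            'qid': 'toy-0',
            'question': 'What is NewsQA?',
            'detected_answers': [{
                'char_spans': [[10, 40]]  # 'a machine comprehension dataset'
            }]
        }
    }]
    toy_dataset, toy_inputs, toy_examples = process_newsqa_dataset(
        toy_data,
        split='train',
        tokenizer=toy_tokenizer,
        max_seq_len=384,
        max_query_len=64,
        trunc_stride=128)
    print('examples:', len(toy_examples), 'spans:', len(toy_inputs))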