# FS-TFP/federatedscope/nlp/hetero_tasks/dataloader/datacollator.py


import math
import numpy as np
import torch
from numpy.random import permutation, poisson


class DataCollatorForMLM(object):
    def __init__(self, tokenizer, mlm_probability=0.15):
        self.tokenizer = tokenizer
        self.mlm_probability = mlm_probability

    def __call__(self, examples):
        """Prepare masked tokens inputs/labels for masked language
        modeling: 80% MASK, 10% random, 10% original."""
        examples = {
            k: torch.stack([x[k] for x in examples])
            for k in examples[0].keys()
        }
        token_ids = examples['token_ids']
        attention_mask = examples['attention_mask']
        labels = token_ids.clone()

        # We sample a few tokens in each sequence for masked-LM training
        # (with probability self.mlm_probability, which defaults to 0.15
        # as in BERT/RoBERTa).
        probability_matrix = torch.full(labels.shape, self.mlm_probability)
        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        probability_matrix.masked_fill_(torch.tensor(special_tokens_mask,
                                                     dtype=torch.bool),
                                        value=0.0)
        if self.tokenizer._pad_token is not None:
            padding_mask = labels.eq(self.tokenizer.pad_token_id)
            probability_matrix.masked_fill_(padding_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        # We only compute loss on masked tokens
        labels[~masked_indices] = -100

        # 80% of the time, we replace masked input tokens with
        # tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(
            labels.shape, 0.8)).bool() & masked_indices
        token_ids[indices_replaced] = self.tokenizer.convert_tokens_to_ids(
            self.tokenizer.mask_token)

        # 10% of the time, we replace masked input tokens with random word
        indices_random = \
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool() & \
            masked_indices & ~indices_replaced
        random_words = torch.randint(len(self.tokenizer),
                                     labels.shape,
                                     dtype=torch.long)
        token_ids[indices_random] = random_words[indices_random]

        # The rest of the time (10% of the time) we keep the masked input
        # tokens unchanged
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'example_indices': examples['example_indices']
        }
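

# Usage sketch for DataCollatorForMLM (illustrative, not part of the original
# module). It assumes a HuggingFace-style tokenizer exposing `mask_token`,
# `pad_token_id` and `get_special_tokens_mask`, and a dataset whose items are
# dicts of 1-D LongTensors with the keys 'token_ids', 'attention_mask' and
# 'example_indices' expected by `__call__`:
#
#     collator = DataCollatorForMLM(tokenizer, mlm_probability=0.15)
#     loader = torch.utils.data.DataLoader(dataset,
#                                          batch_size=8,
#                                          collate_fn=collator)
#     batch = next(iter(loader))
#     # batch['labels'] equals -100 everywhere except at the masked
#     # positions, so only masked tokens contribute to the LM loss.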


class DataCollatorForDenoisingReconstrcution(object):
    """Data collator used for the denoising language modeling task in BART.

    The implementation is based on
    https://github.com/pytorch/fairseq/blob/
    1bba712622b8ae4efb3eb793a8a40da386fe11d0/fairseq/data/denoising_dataset.py.
    The default parameters follow the BART paper
    https://arxiv.org/abs/1910.13461.
    """
    def __init__(self,
                 tokenizer,
                 mask_ratio=0.3,
                 poisson_lambda=3.0,
                 permutate_sentence_ratio=1.0):
        self.tokenizer = tokenizer
        self.mask_ratio = mask_ratio
        self.poisson_lambda = poisson_lambda
        self.permutate_sentence_ratio = permutate_sentence_ratio

    def __call__(self, examples):
        examples = {
            k: torch.stack([x[k] for x in examples])
            for k in examples[0].keys()
        }
        token_ids = examples['token_ids'].numpy()
        attention_mask = examples['attention_mask'].numpy()
        labels = token_ids.copy()

        do_permutate = False
        if self.permutate_sentence_ratio > 0.0:
            permute_sent = self.permutate_sentences(token_ids[:, 1:])
            for i, s in enumerate(permute_sent):
                token_ids[i, 1:] = s
            do_permutate = True

        if self.mask_ratio:
            token_ids, _ = self.add_whole_word_mask(token_ids, do_permutate)

        num_non_padding = np.sum(token_ids != self.tokenizer.pad_token_id,
                                 axis=-1)
        for i in range(len(attention_mask)):
            attention_mask[i][num_non_padding[i]:] = 0

        token_ids = torch.from_numpy(token_ids)
        attention_mask = torch.from_numpy(attention_mask)
        labels = torch.from_numpy(labels)
        return {
            'token_ids': token_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'example_indices': examples['example_indices']
        }
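
    # Illustration (not executed): `permutate_sentences` receives the batch
    # without its first token, splits each row into sentences at the
    # `eoq_token_id`/`eos_token_id` boundaries, and shuffles a
    # `permutate_sentence_ratio` fraction of those sentences among
    # themselves, e.g.
    #     [A1, A2, eoq, B1, eoq, C1, C2, eos, pad, ...]
    # may become
    #     [C1, C2, eoq, A1, A2, eoq, B1, eos, pad, ...].
    # Afterwards the separator that used to be `eos` is rewritten to `eoq`
    # and `eos` is placed back at the last non-padding position.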
    def permutate_sentences(self, inputs):
        results = inputs.copy()
        for i in range(inputs.shape[0]):
            full_stops = (inputs[i] == self.tokenizer.eoq_token_id) | (
                inputs[i] == self.tokenizer.eos_token_id)
            full_stops = full_stops[None, :]
            sentence_ends = np.argwhere(full_stops[:, 1:] *
                                        ~full_stops[:, :-1])
            if len(sentence_ends) == 0:
                continue
            sentence_ends[:, 1] += 2

            num_sentences = np.unique(sentence_ends[:, 0],
                                      return_counts=True)[1]
            num_to_permute = np.ceil(
                (num_sentences * 2 * self.permutate_sentence_ratio) /
                2.0).astype(int)
            sentence_ends = np.split(
                sentence_ends[:, 1],
                np.unique(sentence_ends[:, 0], return_index=True)[1][1:])

            substitutions = np.random.permutation(
                num_sentences[0])[:num_to_permute[0]]
            ordering = np.arange(0, num_sentences[0])
            ordering[substitutions] = substitutions[np.random.permutation(
                num_to_permute[0])]

            index = 0
            for j in ordering:
                start = sentence_ends[0][j - 1] if j > 0 else 0
                sentence = inputs[i, start:sentence_ends[0][j]]
                results[i, index:index + sentence.shape[0]] = sentence
                index += sentence.shape[0]

        num_non_padding = np.sum(results != self.tokenizer.pad_token_id,
                                 axis=-1)
        eos_indices = np.where(results == self.tokenizer.eos_token_id)[1]
        for i, (idx1, idx2) in enumerate(zip(eos_indices, num_non_padding)):
            results[i][idx1] = self.tokenizer.eoq_token_id
            results[i][idx2 - 1] = self.tokenizer.eos_token_id
        return results
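
    # Illustration (not executed): `add_whole_word_mask` below draws span
    # lengths from a Poisson(poisson_lambda) distribution until roughly
    # `mask_ratio` of the non-special tokens are covered, marks random span
    # starts, and grows each span to the right by its sampled length. The
    # start of every span is replaced by `mask_token_id`, while the other
    # positions in the span are deleted and the row is re-padded, e.g.
    #     [bos, t1, t2, t3, t4, t5, eos]  with a span over t2..t4
    # may become
    #     [bos, t1, MASK, t5, eos, pad, pad].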
    def add_whole_word_mask(self, inputs, do_permutate):
        labels = inputs.copy()
        inputs = inputs.copy()

        special_tokens_mask = [
            self.tokenizer.get_special_tokens_mask(
                val, already_has_special_tokens=True)
            for val in labels.tolist()
        ]
        special_tokens_mask = np.array(special_tokens_mask, dtype=bool)

        # determine how many tokens we need to mask in total
        is_token = ~(labels == self.tokenizer.pad_token_id) & \
            ~special_tokens_mask
        num_to_mask = int(
            math.ceil(is_token.astype(float).sum() * self.mask_ratio))
        if num_to_mask == 0:
            return inputs, labels

        # generate a sufficient number of span lengths
        lengths = poisson(lam=self.poisson_lambda, size=(num_to_mask, ))
        while np.cumsum(lengths, 0)[-1] < num_to_mask:
            lengths = np.concatenate([
                lengths,
                poisson(lam=self.poisson_lambda, size=(num_to_mask, ))
            ])

        # remove all spans of length 0
        # Note that BART inserts additional mask tokens where length == 0,
        # which we do not implement for now as it adds additional complexity
        lengths = lengths[lengths > 0]

        # trim to about num_to_mask tokens
        idx = np.argmin(np.abs(np.cumsum(lengths, 0) - num_to_mask)) + 1
        lengths = lengths[:idx + 1]

        # select span start indices
        token_indices = np.argwhere(is_token == 1)
        span_starts = permutation(token_indices.shape[0])[:lengths.shape[0]]

        # prepare mask
        masked_indices = np.array(token_indices[span_starts])
        mask = np.full_like(labels, fill_value=False)

        # mask span start indices
        for mi in masked_indices:
            mask[tuple(mi)] = True
        lengths -= 1

        # fill up spans
        max_index = labels.shape[1] - 1
        remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)
        while np.any(remaining):
            masked_indices[remaining, 1] += 1
            for mi in masked_indices:
                mask[tuple(mi)] = True
            lengths -= 1
            remaining = (lengths > 0) & (masked_indices[:, 1] < max_index)

        # place the mask tokens
        mask[np.where(special_tokens_mask)] = False
        inputs[np.where(mask)] = self.tokenizer.mask_token_id
        if not do_permutate:
            labels[np.where(mask)] = -100
        else:
            labels[np.where(special_tokens_mask)] = -100

        # remove mask tokens that are not starts of spans
        to_remove = (mask == 1) & np.roll((mask == 1), 1, 1)
        new_inputs = np.full_like(labels,
                                  fill_value=self.tokenizer.pad_token_id)
        # shift the surviving tokens of each example to the left and
        # re-pad it to the original length
        for i, example in enumerate(
                np.split(inputs,
                         indices_or_sections=new_inputs.shape[0],
                         axis=0)):
            new_example = example[0][~to_remove[i]]
            new_inputs[i, 0:new_example.shape[0]] = new_example

        return new_inputs, labels
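

# Usage sketch for DataCollatorForDenoisingReconstrcution (illustrative, not
# part of the original module). Besides the attributes required by the MLM
# collator, the tokenizer is assumed to define `eoq_token_id`, `eos_token_id`
# and `mask_token_id`:
#
#     collator = DataCollatorForDenoisingReconstrcution(tokenizer,
#                                                       mask_ratio=0.3,
#                                                       poisson_lambda=3.0)
#     batch = collator(examples)
#     # batch['token_ids'] holds the corrupted inputs (permuted sentences
#     # plus span masking) and batch['labels'] the uncorrupted sequences
#     # that the model learns to reconstruct.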


class DataCollator(object):
    def __init__(self,
                 tokenizer,
                 mlm_probability=0.15,
                 mask_ratio=0.3,
                 poisson_lambda=3.0,
                 permutate_sentence_ratio=1.0):
        self.mlm_collator = DataCollatorForMLM(tokenizer, mlm_probability)
        self.denoise_collator = DataCollatorForDenoisingReconstrcution(
            tokenizer, mask_ratio, poisson_lambda, permutate_sentence_ratio)

    def __call__(self, examples):
        mlm_results = self.mlm_collator(examples)
        denoise_results = self.denoise_collator(examples)
        return {'mlm': mlm_results, 'denoise': denoise_results}
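

# Usage sketch for the combined DataCollator (illustrative, not part of the
# original module). Passed as `collate_fn`, it yields both views of every
# batch, e.g.:
#
#     collator = DataCollator(tokenizer)
#     loader = torch.utils.data.DataLoader(dataset,
#                                          batch_size=8,
#                                          collate_fn=collator)
#     batch = next(iter(loader))
#     mlm_batch, denoise_batch = batch['mlm'], batch['denoise']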