29 lines
881 B
Python
29 lines
881 B
Python
import random
|
|
|
|
from torch import nn
|
|
|
|
|
|
class MaskGenerator(nn.Module):
    """Mask generator: randomly split token indices into unmasked/masked sets.

    Each call to :meth:`forward` (or :meth:`uniform_rand`) shuffles
    ``range(int(num_tokens))`` and assigns the first
    ``int(num_tokens * mask_ratio)`` indices of the permutation to the masked
    set and the remainder to the unmasked set.

    Args:
        num_tokens: Number of tokens; indices ``0 .. int(num_tokens) - 1``
            are generated.
        mask_ratio: Fraction of tokens to mask, in ``[0, 1]``. The masked
            count is ``int(num_tokens * mask_ratio)`` (truncated toward zero).
        sort: Keyword-only. If true (the default), both returned index lists
            are sorted ascending; otherwise they keep their shuffled order.
            Stored as the mutable attribute ``self.sort``.
        rng: Keyword-only. Optional object with a ``shuffle`` method (e.g. a
            ``random.Random(seed)``) used to draw masks reproducibly. When
            ``None`` (the default) the module-level ``random`` generator is
            used, as in the original implementation.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1]`` or is NaN.
    """

    def __init__(self, num_tokens, mask_ratio, *, sort=True, rng=None):
        super().__init__()
        self.num_tokens = num_tokens
        # Fail fast: a negative ratio would otherwise produce a negative
        # slice index in uniform_rand and silently mask the wrong tokens.
        self.mask_ratio = self._validate_mask_ratio(mask_ratio)
        self.sort = sort
        self.rng = rng

    @staticmethod
    def _validate_mask_ratio(mask_ratio):
        """Return ``mask_ratio`` unchanged, raising ValueError if not in [0, 1]."""
        # Written as a negated range check so that NaN (which fails every
        # comparison) is rejected too.
        if not 0 <= mask_ratio <= 1:
            raise ValueError(f"mask_ratio must be in [0, 1], got {mask_ratio!r}")
        return mask_ratio

    def uniform_rand(self):
        """Draw a fresh uniformly random mask and store it on the instance.

        Returns:
            A tuple ``(unmasked_tokens, masked_tokens)`` of lists of int
            indices that together partition ``range(int(num_tokens))``. The
            same lists are stored as ``self.unmasked_tokens`` and
            ``self.masked_tokens``.

        Raises:
            ValueError: If ``self.mask_ratio`` was changed to a value outside
                ``[0, 1]`` after construction.
        """
        # Re-check here because mask_ratio is a plain, mutable attribute.
        self._validate_mask_ratio(self.mask_ratio)
        # getattr keeps instances unpickled from before ``rng`` existed working.
        rng = getattr(self, "rng", None)
        if rng is None:
            rng = random
        order = list(range(int(self.num_tokens)))
        rng.shuffle(order)
        # Truncation toward zero, e.g. 10 tokens at ratio 0.75 masks 7.
        mask_len = int(self.num_tokens * self.mask_ratio)
        masked, unmasked = order[:mask_len], order[mask_len:]
        if self.sort:
            masked, unmasked = sorted(masked), sorted(unmasked)
        self.masked_tokens = masked
        self.unmasked_tokens = unmasked
        return unmasked, masked

    def forward(self):
        """Return ``(unmasked_tokens, masked_tokens)`` for a freshly drawn mask.

        Takes no input; see :meth:`uniform_rand`, which also stores the
        result on the instance.
        """
        return self.uniform_rand()