TrafficWheel/model/STEP/tsformer_components/mask.py

29 lines
881 B
Python

import random
from torch import nn
class MaskGenerator(nn.Module):
    """Randomly split token indices ``0 .. num_tokens - 1`` into unmasked and masked sets.

    Each call to ``forward`` (or ``uniform_rand``) draws a fresh uniform
    random permutation using the module-level ``random`` state. It masks the
    first ``floor(num_tokens * mask_ratio)`` indices of that permutation and
    leaves the rest unmasked.

    Args:
        num_tokens: Number of tokens to split; converted with ``int()`` when
            building the index range.
        mask_ratio: Fraction of tokens to mask, in ``[0, 1]``.
        sort: When True (the default, matching the original hard-coded
            behavior), both returned index lists are sorted ascending.

    Raises:
        ValueError: If ``mask_ratio`` is outside ``[0, 1]``.
    """

    def __init__(self, num_tokens, mask_ratio, sort=True):
        super().__init__()
        # A negative ratio would make ``mask[:mask_len]`` slice from the end,
        # and a ratio > 1 would silently mask everything; reject both early.
        if not 0.0 <= mask_ratio <= 1.0:
            raise ValueError(
                f"mask_ratio must be in [0, 1], got {mask_ratio!r}"
            )
        self.num_tokens = num_tokens
        self.mask_ratio = mask_ratio
        self.sort = sort

    def uniform_rand(self):
        """Draw a new random split and store it on the instance.

        Returns:
            A tuple ``(unmasked_tokens, masked_tokens)`` of index lists. The
            same lists are also stored as ``self.unmasked_tokens`` and
            ``self.masked_tokens``.
        """
        mask = list(range(int(self.num_tokens)))
        random.shuffle(mask)
        # Round away floating-point noise before truncating, otherwise e.g.
        # 100 * 0.29 == 28.999999999999996 would yield 28 instead of 29.
        # Genuinely fractional products (10 * 0.75 == 7.5) still floor.
        mask_len = int(round(self.num_tokens * self.mask_ratio, 6))
        self.masked_tokens = mask[:mask_len]
        self.unmasked_tokens = mask[mask_len:]
        if self.sort:
            self.masked_tokens = sorted(self.masked_tokens)
            self.unmasked_tokens = sorted(self.unmasked_tokens)
        return self.unmasked_tokens, self.masked_tokens

    def forward(self):
        """Return a fresh ``(unmasked_tokens, masked_tokens)`` split.

        Delegates to ``uniform_rand``, which also records the split on
        ``self.unmasked_tokens`` / ``self.masked_tokens``.
        """
        return self.uniform_rand()