import torch
import torch.nn as nn
import torch.nn.functional as F


class GumbelSoftmax(nn.Module):
    """Stochastic top-k selector built on the Gumbel-Softmax relaxation.

    Selection happens along dim 0 of the input: Gumbel noise is added to the
    logits, the result is soft-maxed over dim 0, and (in hard mode) the ``k``
    highest-scoring entries are turned into a 0/1 mask.

    Args:
        k: Number of entries to select along dim 0 in hard mode.
        hard: If True, return a 0/1 mask; otherwise return the relaxed
            (soft) sample.
        tau: Softmax temperature used by ``forward``.
    """

    def __init__(self, k=1000, hard=True, tau=1.0):
        super().__init__()
        self.k = k
        self.hard = hard
        self.tau = tau

    def forward(self, logits):
        """Apply :meth:`gumbel_softmax` with the configured tau, k and hard."""
        return self.gumbel_softmax(logits, self.tau, self.k, self.hard)

    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
        """Draw a Gumbel-Softmax sample over dim 0, optionally as a top-k mask.

        Args:
            logits: Unnormalised scores; entries compete along dim 0.
                Presumably shaped ``(N,)`` or ``(N, 1)`` -- a trailing
                dimension of size 1 is squeezed from the result.
            tau: Softmax temperature.
            k: Number of entries to select per slice along dim 0. Clamped to
                ``logits.size(0)`` so an oversized ``k`` selects everything
                instead of raising inside ``topk``.
            hard: If True, return a mask with ones at the selected positions;
                if False, return the relaxed sample.

        Returns:
            The mask (hard) or relaxed sample (soft), with the last dimension
            squeezed when it has size 1. When gradients are being tracked the
            hard mask is produced with the straight-through estimator, so its
            ones are exact only up to floating-point rounding.
        """
        # Always draw the *relaxed* sample, normalised over the same axis that
        # top-k ranks on. Asking F.gumbel_softmax for a hard sample (or
        # normalising over the last axis) collapses the scores before ranking
        # and makes the top-k choice arbitrary.
        y_soft = F.gumbel_softmax(logits, tau=tau, hard=False, dim=0)
        if not hard:
            return torch.squeeze(y_soft, dim=-1)

        # Build the hard mask: select the top-k entries along dim 0.
        k = min(k, logits.size(0))
        _, indices = y_soft.topk(k, dim=0)
        y_hard = torch.zeros_like(logits).scatter_(0, indices, 1.0)

        if y_soft.requires_grad:
            # Straight-through estimator: the forward value is the hard mask,
            # the backward pass uses the gradient of the relaxed sample.
            y_hard = y_hard - y_soft.detach() + y_soft
        return torch.squeeze(y_hard, dim=-1)