import torch
import torch.nn as nn
import torch.nn.functional as F


def gumbel_softmax(logits, tau=1, k=1000, hard=True, dim=0):
    """Sample a k-hot selection mask over ``dim`` using the Gumbel-top-k trick.

    Args:
        logits: Unnormalized scores. NOTE(review): assumed to be item scores of
            shape ``(N, 1)`` or ``(N,)`` with items along ``dim``; confirm with
            callers.
        tau: Temperature of the relaxed (softmax) distribution; must be > 0.
        k: Number of entries to select along ``dim``. Values larger than
            ``logits.size(dim)`` select every entry instead of raising.
        hard: If True, return a straight-through k-hot mask. If False, return
            ``F.gumbel_softmax(logits, tau, False)`` unchanged (soft samples
            normalized over the LAST dim).
        dim: Dimension along which entries are ranked and selected (hard path).

    Returns:
        A tensor shaped like ``logits`` with a trailing size-1 dim removed
        (e.g. ``(N, 1) -> (N,)``). In the hard case the forward values are
        (numerically) the 0/1 mask, while gradients flow through the softmax
        of the Gumbel-perturbed logits along ``dim``.
    """
    if not hard:
        # Soft path kept exactly as before. NOTE(review): it normalizes over the
        # last dim, unlike the hard path, which selects along ``dim``.
        y_soft = F.gumbel_softmax(logits, tau, False)
        return torch.squeeze(y_soft, dim=-1)

    # Gumbel(0, 1) noise via -log(Exp(1)), as torch's own gumbel_softmax does.
    # The noise must perturb the scores along the selection dim. Ranking the
    # output of F.gumbel_softmax would instead rank a per-row one-hot, or all
    # ones for a (N, 1) input.
    gumbels = -torch.empty_like(logits).exponential_().log()
    perturbed = (logits + gumbels) / tau

    # Select the top-k perturbed scores (Gumbel-top-k sampling). Clamping k
    # avoids a topk error when fewer than k entries exist.
    k = min(k, perturbed.size(dim))
    indices = perturbed.topk(k, dim=dim).indices
    y_hard = torch.zeros_like(logits).scatter_(dim, indices, 1.0)

    # Straight-through estimator: the forward pass uses the hard mask, and the
    # backward pass uses the gradient of the relaxed distribution.
    y_soft = perturbed.softmax(dim)
    mask = y_hard - y_soft.detach() + y_soft
    return torch.squeeze(mask, dim=-1)