23 lines
740 B
Python
23 lines
740 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class GumbelSoftmax(nn.Module):
    """Gumbel-softmax top-k selection along dim 0.

    Gumbel(0, 1) noise is added to ``logits`` and a temperature-scaled
    softmax is taken along dim 0, the same axis the top-k selection uses.
    With ``hard=True`` the ``k`` largest entries along dim 0 become a 0/1
    mask. Gradients reach ``logits`` through the soft weights
    (straight-through estimator). A trailing size-1 dimension, if present,
    is squeezed from the result.

    Args:
        k: Number of entries kept along dim 0 in hard mode. It is clamped to
            ``logits.size(0)`` at call time.
        hard: If True, return a 0/1 top-k mask; otherwise return the soft
            Gumbel-softmax weights.
        tau: Softmax temperature, must be > 0. Defaults to 1, the value
            ``forward`` previously hard-coded.
    """

    def __init__(self, k=1000, hard=True, tau=1.0):
        super().__init__()
        self.k = k
        self.hard = hard
        self.tau = tau

    def forward(self, logits):
        """Apply Gumbel top-k selection to *logits* using this module's settings."""
        return self.gumbel_softmax(logits, self.tau, self.k, self.hard)

    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
        """Sample Gumbel-softmax weights along dim 0, optionally as a top-k mask.

        Args:
            logits: Unnormalised scores; selection happens along dim 0.
                Presumably shaped ``(N, 1)`` or ``(N,)`` (inferred from the
                final squeeze) -- confirm against callers.
            tau: Softmax temperature (> 0).
            k: Number of entries to keep along dim 0 when ``hard`` is True.
                Values larger than ``logits.size(0)`` keep every entry.
            hard: Return a 0/1 mask (straight-through gradient) instead of
                the soft weights.

        Returns:
            Tensor shaped like ``logits`` with a trailing size-1 dim removed.
            In hard mode the forward values are exactly 0/1 and the
            gradient is that of the soft weights.

        Raises:
            ValueError: If ``tau`` is not positive.
        """
        if tau <= 0:
            raise ValueError(f"tau must be positive, got {tau}")

        # Gumbel(0, 1) noise sampled as -log(E), E ~ Exponential(1).
        gumbels = -torch.empty_like(logits).exponential_().log()
        # Normalise along dim 0, the axis that topk/scatter_ select over.
        # F.gumbel_softmax normalises over dim=-1, which for an (N, 1) input
        # is a size-1 axis: every weight becomes 1 and the noise is lost.
        y_soft = ((logits + gumbels) / tau).softmax(dim=0)
        if not hard:
            return torch.squeeze(y_soft, dim=-1)

        # Build the hard mask from the top-k indices along dim 0.
        # topk raises if k exceeds the dimension size, so clamp it.
        k = min(k, y_soft.size(0))
        _, indices = y_soft.topk(k, dim=0)
        y_hard = torch.zeros_like(y_soft).scatter_(0, indices, 1.0)
        # Straight-through: (y_soft - y_soft.detach()) is exactly zero in the
        # forward pass, so values equal y_hard; backward uses y_soft's gradient.
        y = y_hard + (y_soft - y_soft.detach())
        return torch.squeeze(y, dim=-1)