# TrafficWheel/model/REPST/normalizer.py
# (Original listing metadata: 16 lines, 437 B, Python.)

import torch
import torch.nn as nn
import torch.nn.functional as F
def gumbel_softmax(logits, tau=1, k=1000, hard=True):
    """Sample a Gumbel-softmax relaxation of ``logits``, optionally as a top-k mask.

    Args:
        logits: Unnormalized scores. In hard mode, selection happens along
            dim 0. NOTE(review): presumably shaped ``(num_items, 1)`` given the
            trailing squeeze — confirm against the caller.
        tau: Temperature passed to ``F.gumbel_softmax``.
        k: Number of entries set to 1 along dim 0 in hard mode. Must not
            exceed ``logits.size(0)`` (``topk`` raises otherwise).
        hard: If True, return a 0/1 mask with exactly ``k`` ones along dim 0
            (for every index of the remaining dims). The forward values are
            exact 0/1, and gradients flow straight-through via the soft
            scores. If False, return the soft Gumbel-softmax sample.

    Returns:
        Tensor shaped like ``logits``, with the last dimension removed if it
        has size 1 (``torch.squeeze(..., dim=-1)`` is a no-op otherwise).
    """
    if not hard:
        # Soft relaxation, normalized over the last dim (F.gumbel_softmax's
        # default). This path is unchanged from the original behaviour.
        # NOTE(review): the hard branch normalizes over dim 0 instead —
        # confirm which axis the soft path is meant to use.
        y_soft = F.gumbel_softmax(logits, tau=tau, hard=False)
        return torch.squeeze(y_soft, dim=-1)

    # Gumbel-perturbed soft scores over the selection axis (dim 0). They must
    # stay soft: ranking an already one-hot tensor leaves topk only 0/1
    # ties, and a softmax over a size-1 axis collapses every score to 1.
    y_soft = F.gumbel_softmax(logits, tau=tau, hard=False, dim=0)
    # Build the hard mask: ones at the top-k positions along dim 0.
    _, indices = y_soft.topk(k, dim=0)
    y_hard = torch.zeros_like(y_soft).scatter_(0, indices, 1.0)
    # Straight-through estimator: (y_soft - y_soft.detach()) is exactly 0 in
    # the forward pass, so the mask stays exact 0/1, while backward uses the
    # gradient of y_soft (the constant scatter alone would carry no grad).
    y_hard = y_hard + (y_soft - y_soft.detach())
    return torch.squeeze(y_hard, dim=-1)