import torch
import torch.nn as nn
import torch.nn.functional as F


def gumbel_softmax(logits, tau=1, k=1000, hard=True):
    """Sample a (soft or hard) Top-K mask from `logits` via the Gumbel-Softmax trick."""
    # Keep the sample soft here so gradients can flow through the relaxation;
    # the hard Top-K mask built below is not differentiable on its own.
    y_soft = F.gumbel_softmax(logits, tau=tau, hard=False)
    if hard:
        # Build the hard mask: select the Top-K entries along dim 0 and set them to 1.
        _, indices = y_soft.topk(k, dim=0)
        y_hard = torch.zeros_like(logits)
        y_hard.scatter_(0, indices, 1)
        # Straight-through estimator: the forward pass uses the hard mask,
        # the backward pass uses the gradients of the soft sample.
        y_hard = y_hard - y_soft.detach() + y_soft
        return torch.squeeze(y_hard, dim=-1)
    return torch.squeeze(y_soft, dim=-1)


class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added to the denominator for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        :param non_norm: if True, normalization is disabled and the input passes through unchanged
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(
            torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x


class MultiLayerPerceptron(nn.Module):
    """Multi-Layer Perceptron with residual links."""

    def __init__(self, input_dim, hidden_dim) -> None:
        super().__init__()
        self.fc1 = nn.Conv2d(
            in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.fc2 = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.15)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """
        input_data (torch.Tensor): input data with shape [B, D, N, 1]
        (a 4-D tensor is required by the 1x1 Conv2d layers)
        """
        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
        hidden = hidden + input_data  # residual link (assumes input_dim == hidden_dim)
        return hidden
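

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the shapes and hyper-parameters
# below are assumptions, not taken from the original training pipeline).
# It exercises the three components above: the RevIN-style Normalize module,
# the Gumbel Top-K mask, and the residual 1x1-conv MLP.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Assumed input: a batch of B=4 series with T=96 time steps and C=7 channels.
    B, T, C = 4, 96, 7
    x = torch.randn(B, T, C)

    # Normalize in 'norm' mode, then invert with 'denorm'; the round trip
    # should approximately recover the original input.
    rev = Normalize(num_features=C, affine=True)
    x_norm = rev(x, mode='norm')
    x_back = rev(x_norm, mode='denorm')
    print(torch.allclose(x, x_back, atol=1e-4))  # expected: True

    # Hard Top-K mask over 1000 candidates, keeping k=100 of them; gradients
    # reach `logits` through the straight-through estimator.
    logits = torch.randn(1000, requires_grad=True)
    mask = gumbel_softmax(logits, tau=1.0, k=100, hard=True)
    print(mask.shape, float(mask.sum()))  # torch.Size([1000]), ~100.0

    # Residual 1x1-conv MLP; input_dim == hidden_dim so the residual link works,
    # and the 4-D shape [B, D, N, 1] is assumed for illustration.
    mlp = MultiLayerPerceptron(input_dim=32, hidden_dim=32)
    out = mlp(torch.randn(B, 32, 10, 1))
    print(out.shape)  # torch.Size([4, 32, 10, 1])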