import torch
import torch.nn as nn
import torch.nn.functional as F


def gumbel_softmax(logits, tau=1, k=1000, hard=True):
    # Draw a Gumbel-Softmax sample from `logits` with temperature `tau`.
    y_soft = F.gumbel_softmax(logits, tau=tau, hard=hard)

    if hard:
        # generate a hard binary mask: keep the top-k entries along dim 0
        _, indices = y_soft.topk(k, dim=0)  # select the top-k positions
        y_hard = torch.zeros_like(logits)
        y_hard.scatter_(0, indices, 1)
        return torch.squeeze(y_hard, dim=-1)
    return torch.squeeze(y_soft, dim=-1)

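# Illustrative usage sketch (not part of the original module): the shapes and
# sizes below are assumptions chosen only to show the interface. With
# hard=True the helper returns a {0, 1} mask holding exactly k ones along dim 0.
def _example_gumbel_softmax():
    logits = torch.randn(5000, 1)                      # one score per candidate position
    mask = gumbel_softmax(logits, tau=0.5, k=1000, hard=True)
    print(mask.shape, int(mask.sum()))                 # torch.Size([5000]) 1000
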
class Normalize(nn.Module):

    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        :param subtract_last: if True, subtract the last time step instead of the mean
        :param non_norm: if True, bypass normalization and return the input unchanged
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        # mode == 'norm': compute statistics from x and normalize it;
        # mode == 'denorm': undo the normalization using the stored statistics.
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Per-instance statistics, reduced over all dims except batch and channel.
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            # invert the affine transform; eps**2 guards against a zero weight
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x

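# Illustrative usage sketch (not part of the original module): a RevIN-style
# round trip on a [batch, time, channel] tensor. The sizes are assumptions
# chosen only for demonstration.
def _example_normalize():
    norm = Normalize(num_features=7, affine=True)
    x = torch.randn(32, 96, 7)                    # [B, T, C]
    x_norm = norm(x, mode='norm')                 # store stats, then normalize
    x_back = norm(x_norm, mode='denorm')          # invert with the stored stats
    print(torch.allclose(x, x_back, atol=1e-4))   # True, up to numerical error
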
class MultiLayerPerceptron(nn.Module):
    """Multi-Layer Perceptron with residual links."""

    def __init__(self, input_dim, hidden_dim) -> None:
        super().__init__()
        self.fc1 = nn.Conv2d(
            in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.fc2 = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.15)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """
        input_data (torch.Tensor): input data with shape [B, D, N, 1]
        (the 1x1 Conv2d layers operate on a 4-D tensor)
        """
        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
        hidden = hidden + input_data  # residual link (assumes input_dim == hidden_dim)
        return hidden
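
# Illustrative usage sketch (not part of the original module): apply the MLP to
# a [batch, channel, node, 1] tensor; input_dim must equal hidden_dim for the
# residual link to broadcast. The sizes are assumptions for demonstration.
def _example_mlp():
    mlp = MultiLayerPerceptron(input_dim=64, hidden_dim=64)
    x = torch.randn(8, 64, 170, 1)     # [B, D, N, 1]
    out = mlp(x)
    print(out.shape)                   # torch.Size([8, 64, 170, 1])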