import torch
import torch.nn as nn
import torch.nn.functional as F


# =========================
# Spatial Normalization
# =========================
class SNorm(nn.Module):
    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.eps = eps

    def forward(self, x):
        # x: (B, C, N, T); normalize over the node dimension
        mean = x.mean(dim=2, keepdim=True)
        var = x.var(dim=2, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta


# =========================
# Temporal Normalization
# =========================
class TNorm(nn.Module):
    def __init__(self, num_nodes, channels, momentum=0.1, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer("running_mean", torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer("running_var", torch.ones(1, channels, num_nodes, 1))
        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        # x: (B, C, N, T); normalize over the batch and time dimensions
        if self.training:
            mean = x.mean(dim=(0, 3), keepdim=True)
            var = x.var(dim=(0, 3), keepdim=True, unbiased=False)
            # Update the running statistics in place, outside the autograd
            # graph; without no_grad() the buffers would drag a growing
            # computation graph across training iterations.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean = self.running_mean
            var = self.running_var
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta
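
# Quick shape check for the two norm layers (illustrative sketch only; the
# sizes below are arbitrary assumptions, not values from any model config):
#
#   x = torch.randn(4, 32, 207, 12)          # (B, C, N, T)
#   SNorm(32)(x).shape      # (4, 32, 207, 12), normalized across nodes N
#   TNorm(207, 32)(x).shape # (4, 32, 207, 12), normalized across B and T
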
# =========================
# STNorm WaveNet
# =========================
class STNormNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.dropout = args["dropout"]
        self.num_nodes = args["num_nodes"]
        self.in_dim = args["in_dim"]
        self.out_dim = args["out_dim"]
        self.channels = args["channels"]
        self.kernel_size = args["kernel_size"]
        self.use_snorm = args["snorm_bool"]
        self.use_tnorm = args["tnorm_bool"]

        self.start_conv = nn.Conv2d(self.in_dim, self.channels, kernel_size=(1, 1))

        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.snorms = nn.ModuleList()
        self.tnorms = nn.ModuleList()

        # Dilation doubles each layer, so a block of L layers adds
        # (kernel_size - 1) * (2**L - 1) to the receptive field.
        self.receptive_field = 1
        for b in range(self.blocks):
            dilation = 1
            rf_add = self.kernel_size - 1
            for _ in range(self.layers):
                if self.use_snorm:
                    self.snorms.append(SNorm(self.channels))
                if self.use_tnorm:
                    self.tnorms.append(TNorm(self.num_nodes, self.channels))
                self.filter_convs.append(
                    nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)
                )
                self.gate_convs.append(
                    nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)
                )
                self.residual_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
                self.skip_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
                self.receptive_field += rf_add
                rf_add *= 2
                dilation *= 2

        self.end_conv_1 = nn.Conv2d(self.channels, self.channels, (1, 1))
        self.end_conv_2 = nn.Conv2d(self.channels, self.out_dim, (1, 1))

    def forward(self, input):
        # (B, T, N, F) -> (B, F, N, T)
        x = input[..., :self.in_dim].transpose(1, 3)

        # left-pad the time axis so the stacked dilated convs see a full
        # receptive field
        if x.size(3) < self.receptive_field:
            x = F.pad(x, (self.receptive_field - x.size(3), 0, 0, 0))

        x = self.start_conv(x)
        skip = None
        norm_idx = 0
        for i in range(self.blocks * self.layers):
            residual = x

            # ---------- STNorm (additive fusion of normalized views) ----------
            if self.use_tnorm:
                x = x + 0.5 * self.tnorms[norm_idx](x)
            if self.use_snorm:
                x = x + 0.5 * self.snorms[norm_idx](x)
            norm_idx += 1

            # ---------- Dilated gated conv ----------
            filter_out = torch.tanh(self.filter_convs[i](x))
            gate_out = torch.sigmoid(self.gate_convs[i](x))
            x = filter_out * gate_out

            # ---------- Skip connection (cropped to match time length) ----------
            s = self.skip_convs[i](x)
            if skip is None:
                skip = s
            else:
                skip = skip[..., -s.size(3):] + s

            # ---------- Residual connection (cropped to match time length) ----------
            x = self.residual_convs[i](x)
            x = x + residual[..., -x.size(3):]

        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # (B, out_dim, N, T_out)

        # left-pad the output so its time length matches the input horizon
        T_out = x.size(3)
        T_target = input.size(1)
        if T_out < T_target:
            x = F.pad(x, (T_target - T_out, 0, 0, 0))
        x = x.transpose(1, 3)  # (B, T, N, out_dim)
        return x
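

# Minimal smoke test (a sketch, not part of the original model code); the
# hyperparameter values below are arbitrary assumptions chosen only to
# exercise the forward pass end to end.
if __name__ == "__main__":
    args = {
        "blocks": 4,
        "layers": 2,
        "dropout": 0.3,
        "num_nodes": 207,
        "in_dim": 2,
        "out_dim": 12,
        "channels": 32,
        "kernel_size": 2,
        "snorm_bool": True,
        "tnorm_bool": True,
    }
    model = STNormNet(args)
    dummy = torch.randn(8, 12, args["num_nodes"], args["in_dim"])  # (B, T, N, F)
    out = model(dummy)
    print(out.shape)  # torch.Size([8, 12, 207, 12]) -> (B, T, N, out_dim)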