# TrafficWheel/model/STNorm/STNorm.py

import torch
import torch.nn as nn
import torch.nn.functional as F
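
# Tensor layout used throughout: (B, C, N, T) = (batch, channels, nodes, time).
# SNorm normalizes across nodes, TNorm across batch and time; both are applied
# once per layer inside the WaveNet-style backbone below.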
# =========================
# Spatial Normalization
# =========================
class SNorm(nn.Module):
    """Normalizes each (batch, channel, time) slice across the node dimension."""

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels, 1, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1))
        self.eps = eps

    def forward(self, x):
        # x: (B, C, N, T); normalize over the node dimension (dim=2)
        mean = x.mean(dim=2, keepdim=True)
        var = x.var(dim=2, keepdim=True, unbiased=False)
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta
# =========================
# Temporal Normalization
# =========================
class TNorm(nn.Module):
    """Normalizes each (channel, node) series across batch and time,
    keeping per-node running statistics (BatchNorm-style)."""

    def __init__(self, num_nodes, channels, momentum=0.1, eps=1e-5):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
        self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer("running_mean", torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer("running_var", torch.ones(1, channels, num_nodes, 1))
        self.momentum = momentum
        self.eps = eps

    def forward(self, x):
        # x: (B, C, N, T); statistics over batch and time (dims 0 and 3)
        if self.training:
            mean = x.mean(dim=(0, 3), keepdim=True)
            var = x.var(dim=(0, 3), keepdim=True, unbiased=False)
            # Update the registered buffers in place so they follow .to(device)
            # and land in state_dict; no_grad keeps autograd from tracking the
            # running statistics across iterations.
            with torch.no_grad():
                self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
                self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
        else:
            mean = self.running_mean
            var = self.running_var
        x = (x - mean) / torch.sqrt(var + self.eps)
        return x * self.gamma + self.beta
# =========================
# STNorm WaveNet
# =========================
class STNormNet(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.dropout = args["dropout"]  # stored for config parity; not applied in this variant
        self.num_nodes = args["num_nodes"]
        self.in_dim = args["in_dim"]
        self.out_dim = args["out_dim"]
        self.channels = args["channels"]
        self.kernel_size = args["kernel_size"]
        self.use_snorm = args["snorm_bool"]
        self.use_tnorm = args["tnorm_bool"]
        self.start_conv = nn.Conv2d(self.in_dim, self.channels, kernel_size=(1, 1))
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.snorms = nn.ModuleList()
        self.tnorms = nn.ModuleList()
        self.receptive_field = 1
        for b in range(self.blocks):
            # dilation doubles per layer, so the receptive field grows as
            # 1 + blocks * (kernel_size - 1) * (2**layers - 1)
            dilation = 1
            rf_add = self.kernel_size - 1
            for _ in range(self.layers):
                if self.use_snorm:
                    self.snorms.append(SNorm(self.channels))
                if self.use_tnorm:
                    self.tnorms.append(TNorm(self.num_nodes, self.channels))
                self.filter_convs.append(
                    nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)
                )
                self.gate_convs.append(
                    nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)
                )
                self.residual_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
                self.skip_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
                self.receptive_field += rf_add
                rf_add *= 2
                dilation *= 2
        self.end_conv_1 = nn.Conv2d(self.channels, self.channels, (1, 1))
        self.end_conv_2 = nn.Conv2d(self.channels, self.out_dim, (1, 1))
    def forward(self, input):
        # (B, T, N, F) -> (B, F, N, T)
        x = input[..., :self.in_dim].transpose(1, 3)
        # left-pad the time axis so the dilated convs see a full receptive field
        if x.size(3) < self.receptive_field:
            x = F.pad(x, (self.receptive_field - x.size(3), 0, 0, 0))
        x = self.start_conv(x)
        skip = None
        norm_idx = 0
        for i in range(self.blocks * self.layers):
            residual = x
            # ---------- STNorm (safe fusion) ----------
            if self.use_tnorm:
                x = x + 0.5 * self.tnorms[norm_idx](x)
            if self.use_snorm:
                x = x + 0.5 * self.snorms[norm_idx](x)
            norm_idx += 1
            # ---------- Dilated gated conv ----------
            filter_out = torch.tanh(self.filter_convs[i](x))
            gate_out = torch.sigmoid(self.gate_convs[i](x))
            x = filter_out * gate_out
            # ---------- Skip (crop stored skip to the new, shorter time axis) ----------
            s = self.skip_convs[i](x)
            if skip is None:
                skip = s
            else:
                skip = skip[..., -s.size(3):] + s
            # ---------- Residual (crop to match the conv output length) ----------
            x = self.residual_convs[i](x)
            x = x + residual[..., -x.size(3):]
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)  # (B, out_dim, N, T')
        # left-pad the time axis so the output horizon matches the input length
        T_out = x.size(3)
        T_target = input.size(1)
        if T_out < T_target:
            x = F.pad(x, (T_target - T_out, 0, 0, 0))
        return x.transpose(1, 3)  # (B, T, N, out_dim)
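
# ---------------------------------------------------------------------------
# Minimal smoke test (sketch): the config values below are illustrative
# assumptions for shape-checking, not TrafficWheel's actual configuration.
# With blocks=2, layers=2, kernel_size=2 the receptive field is
# 1 + 2 * 1 * (2**2 - 1) = 7, so a 12-step input needs no input padding.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "blocks": 2,
        "layers": 2,
        "dropout": 0.3,
        "num_nodes": 8,
        "in_dim": 2,
        "out_dim": 1,
        "channels": 16,
        "kernel_size": 2,
        "snorm_bool": True,
        "tnorm_bool": True,
    }
    model = STNormNet(args)
    x = torch.randn(4, 12, args["num_nodes"], args["in_dim"])  # (B, T, N, F)
    y = model(x)
    assert y.shape == (4, 12, args["num_nodes"], args["out_dim"])
    print("output:", tuple(y.shape))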