import torch
import torch.nn as nn
import torch.nn.functional as F


class SNorm(nn.Module):
    """Spatial normalization (SNorm): standardizes each channel across the node dimension."""

    def __init__(self, channels):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(channels))
        self.gamma = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        # x: (batch, channels, num_nodes, time_steps); normalize over the node axis
        x_norm = (x - x.mean(2, keepdim=True)) / (x.var(2, keepdim=True, unbiased=True) + 1e-5) ** 0.5
        return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)


class TNorm(nn.Module):
    """Temporal normalization (TNorm): per-node, per-channel standardization over the batch
    and time axes, with BatchNorm-style running statistics."""

    def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1):
        super().__init__()
        self.track_running_stats = track_running_stats
        self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
        self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
        self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1))
        self.momentum = momentum

    def forward(self, x):
        if self.track_running_stats:
            mean = x.mean((0, 3), keepdim=True)
            var = x.var((0, 3), keepdim=True, unbiased=False)
            if self.training:
                # update running statistics; the running variance stores the unbiased estimate
                n = x.shape[3] * x.shape[0]
                with torch.no_grad():
                    self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean
                    self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var
            else:
                mean = self.running_mean
                var = self.running_var
        else:
            mean = x.mean(3, keepdim=True)
            var = x.var(3, keepdim=True, unbiased=True)
        x_norm = (x - mean) / (var + 1e-5) ** 0.5
        return x_norm * self.gamma + self.beta


class stnorm(nn.Module):
    """WaveNet-style temporal convolution network with optional spatial (SNorm) and
    temporal (TNorm) normalization, following the Graph WaveNet (GWN) layout."""

    def __init__(self, args):
        super().__init__()
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.snorm_bool = args["snorm_bool"]
        self.tnorm_bool = args["tnorm_bool"]
        self.num_nodes = args["num_nodes"]
        in_dim = args["in_dim"]
        out_dim = args["out_dim"]
        channels = args["channels"]
        kernel_size = args["kernel_size"]

        # input projection
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1))
        # module lists
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.sn = nn.ModuleList() if self.snorm_bool else None
        self.tn = nn.ModuleList() if self.tnorm_bool else None

        # receptive field of the dilated convolution stack
        self.receptive_field = 1

        # build the network layers
        for b in range(self.blocks):
            new_dilation = 1
            additional_scope = kernel_size - 1  # reset per block to match the dilation reset
            for i in range(self.layers):
                if self.tnorm_bool:
                    self.tn.append(TNorm(self.num_nodes, channels))
                if self.snorm_bool:
                    self.sn.append(SNorm(channels))

                # dilated convolutions; channels is used directly as the input width
                # (normalized features are added to x rather than concatenated)
                self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))

                # residual and skip connections
                self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
                self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
                # grow the receptive field
                self.receptive_field += additional_scope
                additional_scope *= 2
                new_dilation *= 2

        # output layers
        self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)

    def forward(self, input):
        # input: (batch, time_steps, num_nodes, features); keep the first feature and
        # rearrange to (batch, features, num_nodes, time_steps), consistent with GWN
        x = input[..., 0:1].transpose(1, 3)

        # left-pad the time axis so it covers the receptive field
        in_len = x.size(3)
        if in_len < self.receptive_field:
            x = F.pad(x, (self.receptive_field - in_len, 0, 0, 0))

        # input projection
        x = self.start_conv(x)
        skip = 0

        # WaveNet layers
        for i in range(self.blocks * self.layers):
            residual = x

            # add spatial and temporal normalization on top of the original features
            # (out-of-place addition so `residual` keeps the un-normalized activations)
            x_norm = x
            if self.tnorm_bool:
                x_norm = x_norm + self.tn[i](x)
            if self.snorm_bool:
                x_norm = x_norm + self.sn[i](x)

            # gated dilated convolution
            filter = torch.tanh(self.filter_convs[i](x_norm))
            gate = torch.sigmoid(self.gate_convs[i](x_norm))
            x = filter * gate

            # skip connection
            s = self.skip_convs[i](x)
            skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0)

            # residual connection
            x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):]

        # output head
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
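
# Minimal usage sketch. The hyperparameter values, batch size, and input length below are
# illustrative assumptions; the input layout (batch, time_steps, num_nodes, features) is
# inferred from the transpose at the top of stnorm.forward().
if __name__ == "__main__":
    args = {
        "dropout": 0.3,
        "blocks": 4,
        "layers": 2,
        "snorm_bool": True,
        "tnorm_bool": True,
        "num_nodes": 207,
        "in_dim": 1,
        "out_dim": 12,
        "channels": 32,
        "kernel_size": 2,
    }
    model = stnorm(args)
    dummy = torch.randn(8, 12, args["num_nodes"], 2)  # (batch, time, nodes, features)
    out = model(dummy)
    print(out.shape)  # expected (8, 12, 207, 1): out_dim channels, one remaining time step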