import torch
import torch.nn as nn
import torch.nn.functional as F


class SNorm(nn.Module):
    """Spatial normalization: normalizes each channel across the node dimension."""

    def __init__(self, channels):
        super().__init__()
        self.beta = nn.Parameter(torch.zeros(channels))
        self.gamma = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        # x: (batch, channels, num_nodes, time_steps); normalize over the node axis.
        x_norm = (x - x.mean(2, keepdim=True)) / (x.var(2, keepdim=True, unbiased=True) + 1e-5) ** 0.5
        return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)


class TNorm(nn.Module):
    """Temporal normalization: normalizes each (channel, node) series over batch and time."""

    def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1):
        super().__init__()
        self.track_running_stats = track_running_stats
        self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
        self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
        self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1))
        self.momentum = momentum

    def forward(self, x):
        # x: (batch, channels, num_nodes, time_steps)
        if self.track_running_stats:
            mean = x.mean((0, 3), keepdim=True)
            var = x.var((0, 3), keepdim=True, unbiased=False)
            if self.training:
                n = x.shape[3] * x.shape[0]
                with torch.no_grad():
                    self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean
                    # Store the unbiased variance estimate, as BatchNorm does.
                    self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var
            else:
                mean = self.running_mean
                var = self.running_var
        else:
            mean = x.mean(3, keepdim=True)
            var = x.var(3, keepdim=True, unbiased=True)
        x_norm = (x - mean) / (var + 1e-5) ** 0.5
        return x_norm * self.gamma + self.beta


class stnorm(nn.Module):
    """WaveNet-style temporal convolution backbone with optional spatial/temporal normalization."""

    def __init__(self, args):
        super().__init__()
        self.dropout = args["dropout"]
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.snorm_bool = args["snorm_bool"]
        self.tnorm_bool = args["tnorm_bool"]
        self.num_nodes = args["num_nodes"]
        in_dim = args["in_dim"]
        out_dim = args["out_dim"]
        channels = args["channels"]
        kernel_size = args["kernel_size"]

        # Initial 1x1 convolution lifting the input features to the working width.
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1))

        # Module lists for the stacked dilated-convolution layers.
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.sn = nn.ModuleList() if self.snorm_bool else None
        self.tn = nn.ModuleList() if self.tnorm_bool else None

        # Receptive field of the stacked dilated convolutions.
        self.receptive_field = 1

        # Build the network layers.
        for b in range(self.blocks):
            new_dilation = 1
            additional_scope = kernel_size - 1  # reset per block, matching the dilation reset

            for i in range(self.layers):
                if self.tnorm_bool:
                    self.tn.append(TNorm(self.num_nodes, channels))
                if self.snorm_bool:
                    self.sn.append(SNorm(channels))

                # Dilated convolutions; the channel width stays fixed because the
                # normalized features are added to the input rather than concatenated.
                self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))

                # Residual and skip connections.
                self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                     kernel_size=(1, 1)))
                self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                 kernel_size=(1, 1)))

                # Update the receptive field.
                self.receptive_field += additional_scope
                additional_scope *= 2
                new_dilation *= 2

        # Output layers.
        self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels,
                                    kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim,
                                    kernel_size=(1, 1), bias=True)

    def forward(self, input):
        # Input handling, consistent with GWN: take the first feature and reshape to
        # (batch, features, num_nodes, time_steps).
        x = input[..., 0:1].transpose(1, 3)

        # Left-pad the time axis up to the receptive field.
        in_len = x.size(3)
        if in_len < self.receptive_field:
            x = F.pad(x, (self.receptive_field - in_len, 0, 0, 0))
        # Initial convolution.
        x = self.start_conv(x)
        skip = 0

        # WaveNet layers.
        for i in range(self.blocks * self.layers):
            residual = x

            # Spatial and temporal normalization, added onto the original features
            # instead of concatenated (out-of-place so `x`/`residual` are not modified).
            x_norm = x
            if self.tnorm_bool:
                x_norm = x_norm + self.tn[i](x)
            if self.snorm_bool:
                x_norm = x_norm + self.sn[i](x)

            # Gated dilated convolution.
            filter = torch.tanh(self.filter_convs[i](x_norm))
            gate = torch.sigmoid(self.gate_convs[i](x_norm))
            x = filter * gate

            # Skip connection (crop earlier skips to the current temporal length).
            s = self.skip_convs[i](x)
            skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0)

            # Residual connection.
            x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):]

        # Output processing.
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        x = self.end_conv_2(x)
        return x
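

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (an assumption-laden example, not part of the model
# definition): the args keys mirror those read in stnorm.__init__, the input
# layout (batch, time_steps, num_nodes, features) is inferred from the slicing
# and transpose at the top of forward(), and all hyperparameter values below
# are arbitrary placeholders.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    args = {
        "dropout": 0.3,
        "blocks": 4,
        "layers": 2,
        "snorm_bool": True,
        "tnorm_bool": True,
        "num_nodes": 207,
        "in_dim": 1,       # forward() only uses the first input feature
        "out_dim": 12,     # forecasting horizon
        "channels": 32,
        "kernel_size": 2,
    }
    model = stnorm(args)

    # 8 samples, 12 historical steps, 207 nodes, 1 feature.
    dummy = torch.randn(8, 12, 207, 1)
    out = model(dummy)
    # With these settings the receptive field is 13, so the 12-step input is
    # left-padded and the temporal dimension collapses to a single step.
    print(out.shape)  # expected: torch.Size([8, 12, 207, 1])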