# TrafficWheel/model/STNorm/STNorm.py

import torch
import torch.nn as nn
import torch.nn.functional as F
class SNorm(nn.Module):
    """Spatial normalization: standardize features across the node axis (dim 2).

    Each (batch, channel, time) slice is normalized over nodes, then rescaled
    by a learnable per-channel affine transform.
    """

    def __init__(self, channels):
        super().__init__()
        # Per-channel shift/scale applied after normalization.
        self.beta = nn.Parameter(torch.zeros(channels))
        self.gamma = nn.Parameter(torch.ones(channels))

    def forward(self, x):
        # x assumed (batch, channels, num_nodes, time) — TODO confirm with caller.
        node_mean = x.mean(2, keepdim=True)
        node_var = x.var(2, keepdim=True, unbiased=True)
        normed = (x - node_mean) / (node_var + 1e-5) ** 0.5
        scale = self.gamma.view(1, -1, 1, 1)
        shift = self.beta.view(1, -1, 1, 1)
        return normed * scale + shift
class TNorm(nn.Module):
    """Temporal normalization with BatchNorm-style running statistics.

    Maintains one mean/variance pair per (channel, node) position. In training
    (with ``track_running_stats``) it normalizes with batch statistics over the
    (batch, time) axes and updates the running buffers; in eval it normalizes
    with the running buffers. With ``track_running_stats=False`` it normalizes
    each sample over its own time axis only.
    """

    def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1):
        super().__init__()
        self.track_running_stats = track_running_stats
        self.momentum = momentum
        # Learnable per-(channel, node) affine parameters.
        self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
        self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
        # EMA buffers, updated only in training mode.
        self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1))
        self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1))

    def forward(self, x):
        if not self.track_running_stats:
            # Instance-style statistics over the time axis alone.
            mean = x.mean(3, keepdim=True)
            var = x.var(3, keepdim=True, unbiased=True)
        else:
            # Batch statistics over (batch, time); biased variance, as BatchNorm does.
            mean = x.mean((0, 3), keepdim=True)
            var = x.var((0, 3), keepdim=True, unbiased=False)
            if self.training:
                sample_count = x.shape[0] * x.shape[3]
                with torch.no_grad():
                    # Exponential moving average; n/(n-1) unbiases the stored variance.
                    self.running_mean = (self.momentum * mean
                                         + (1 - self.momentum) * self.running_mean)
                    self.running_var = (self.momentum * var * sample_count / (sample_count - 1)
                                        + (1 - self.momentum) * self.running_var)
            else:
                mean, var = self.running_mean, self.running_var
        normed = (x - mean) / (var + 1e-5) ** 0.5
        return normed * self.gamma + self.beta
class stnorm(nn.Module):
    """WaveNet-style dilated TCN with optional spatial (SNorm) and temporal
    (TNorm) normalization branches summed onto the features before each gated
    convolution (STNorm, adapted to add rather than concatenate the branches).

    Expected ``args`` keys: dropout, blocks, layers, snorm_bool, tnorm_bool,
    num_nodes, in_dim, out_dim, channels, kernel_size.
    """

    def __init__(self, args):
        super().__init__()
        self.dropout = args["dropout"]  # NOTE(review): stored but never applied in forward
        self.blocks = args["blocks"]
        self.layers = args["layers"]
        self.snorm_bool = args["snorm_bool"]
        self.tnorm_bool = args["tnorm_bool"]
        self.num_nodes = args["num_nodes"]
        in_dim = args["in_dim"]
        out_dim = args["out_dim"]
        channels = args["channels"]
        kernel_size = args["kernel_size"]
        # 1x1 conv lifting raw input features to the working channel width.
        self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1))
        # Per-layer module lists.
        self.filter_convs = nn.ModuleList()
        self.gate_convs = nn.ModuleList()
        self.residual_convs = nn.ModuleList()
        self.skip_convs = nn.ModuleList()
        self.sn = nn.ModuleList() if self.snorm_bool else None
        self.tn = nn.ModuleList() if self.tnorm_bool else None
        # Receptive field grows by (kernel_size - 1) * dilation per layer.
        self.receptive_field = 1
        additional_scope = kernel_size - 1
        for b in range(self.blocks):
            new_dilation = 1
            for i in range(self.layers):
                if self.tnorm_bool:
                    self.tn.append(TNorm(self.num_nodes, channels))
                if self.snorm_bool:
                    self.sn.append(SNorm(channels))
                # Dilated gated convolutions. The normalized branches are summed
                # onto x (not concatenated), so the input width stays `channels`.
                self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                   kernel_size=(1, kernel_size), dilation=new_dilation))
                self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                 kernel_size=(1, kernel_size), dilation=new_dilation))
                # 1x1 convs for the residual and skip paths.
                self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                     kernel_size=(1, 1)))
                self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
                                                 kernel_size=(1, 1)))
                self.receptive_field += additional_scope
                additional_scope *= 2
                new_dilation *= 2
        # Output head: two 1x1 convs mapping channels -> out_dim.
        self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels,
                                    kernel_size=(1, 1), bias=True)
        self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim,
                                    kernel_size=(1, 1), bias=True)

    def forward(self, input):
        # Input layout assumed (batch, time, nodes, features); only the first
        # feature is used, giving x of shape (batch, 1, nodes, time) — as in GWN.
        x = input[..., 0:1].transpose(1, 3)
        # Left-pad the time axis so it covers the full receptive field.
        in_len = x.size(3)
        if in_len < self.receptive_field:
            x = F.pad(x, (self.receptive_field - in_len, 0, 0, 0))
        x = self.start_conv(x)
        skip = 0
        for i in range(self.blocks * self.layers):
            residual = x
            # Sum the normalized views onto the features. Out-of-place addition
            # is essential: the original `x_norm = x; x_norm += ...` mutated x
            # (and thus `residual`, which aliases it) in place, corrupting the
            # residual path and feeding an already-augmented x into the second
            # normalizer; it could also break autograd.
            x_norm = x
            if self.tnorm_bool:
                x_norm = x_norm + self.tn[i](x)
            if self.snorm_bool:
                x_norm = x_norm + self.sn[i](x)
            # Gated dilated convolution (WaveNet activation unit).
            filt = torch.tanh(self.filter_convs[i](x_norm))
            gate = torch.sigmoid(self.gate_convs[i](x_norm))
            x = filt * gate
            # Skip connection: crop the accumulated skip to the shorter time axis.
            s = self.skip_convs[i](x)
            skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0)
            # Residual connection, cropped to match the convolved length.
            x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):]
        # Output head.
        x = F.relu(skip)
        x = F.relu(self.end_conv_1(x))
        return self.end_conv_2(x)