REPST #3
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: AirQuality
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 35 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 6 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 6 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 35
|
||||
in_dim: 6
|
||||
out_dim: 6
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: BJTaxi-InFlow
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 1024 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 1024
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: BJTaxi-OutFlow
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 1024 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 1024
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ model:
|
|||
tnorm_bool: True
|
||||
num_nodes: 207
|
||||
in_dim: 1
|
||||
out_dim: 24
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: NYCBike-InFlow
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 128 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 128
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: NYCBike-OutFlow
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 128 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 128
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: PEMS-BAY
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 325 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 325
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ basic:
|
|||
dataset: SolarEnergy
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: MTGNN
|
||||
model: STNorm
|
||||
seed: 2023
|
||||
|
||||
data:
|
||||
|
|
@ -19,28 +19,16 @@ data:
|
|||
val_ratio: 0.2
|
||||
|
||||
model:
|
||||
gcn_true: True # 是否使用图卷积网络 (bool)
|
||||
buildA_true: True # 是否动态构建邻接矩阵 (bool)
|
||||
subgraph_size: 20 # 子图大小 (int)
|
||||
num_nodes: 137 # 节点数量 (int)
|
||||
node_dim: 40 # 节点嵌入维度 (int)
|
||||
dilation_exponential: 1 # 膨胀卷积指数 (int)
|
||||
conv_channels: 32 # 卷积通道数 (int)
|
||||
residual_channels: 32 # 残差通道数 (int)
|
||||
skip_channels: 64 # 跳跃连接通道数 (int)
|
||||
end_channels: 128 # 输出层通道数 (int)
|
||||
seq_len: 24 # 输入序列长度 (int)
|
||||
in_dim: 1 # 输入特征维度 (int)
|
||||
out_len: 24 # 输出序列长度 (int)
|
||||
out_dim: 1 # 输出预测维度 (int)
|
||||
layers: 3 # 模型层数 (int)
|
||||
propalpha: 0.05 # 图传播参数alpha (float)
|
||||
tanhalpha: 3 # tanh激活参数alpha (float)
|
||||
layer_norm_affline: True # 层归一化是否使用affine变换 (bool)
|
||||
gcn_depth: 2 # 图卷积深度 (int)
|
||||
dropout: 0.3 # dropout率 (float)
|
||||
predefined_A: null # 预定义邻接矩阵 (optional, None)
|
||||
static_feat: null # 静态特征 (optional, None)
|
||||
dropout: 0.2
|
||||
blocks: 2
|
||||
layers: 2
|
||||
snorm_bool: True
|
||||
tnorm_bool: True
|
||||
num_nodes: 137
|
||||
in_dim: 1
|
||||
out_dim: 1
|
||||
channels: 32
|
||||
kernel_size: 2
|
||||
|
||||
train:
|
||||
batch_size: 64
|
||||
|
|
|
|||
|
|
@ -2,139 +2,147 @@ import torch
|
|||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# =========================
|
||||
# Spatial Normalization
|
||||
# =========================
|
||||
class SNorm(nn.Module):
|
||||
def __init__(self, channels):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
self.beta = nn.Parameter(torch.zeros(channels))
|
||||
self.gamma = nn.Parameter(torch.ones(channels))
|
||||
self.gamma = nn.Parameter(torch.ones(1, channels, 1, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1))
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
x_norm = (x - x.mean(2, keepdims=True)) / (x.var(2, keepdims=True, unbiased=True) + 1e-5) ** 0.5
|
||||
return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1)
|
||||
# normalize over node dimension
|
||||
mean = x.mean(dim=2, keepdim=True)
|
||||
var = x.var(dim=2, keepdim=True, unbiased=False)
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return x * self.gamma + self.beta
|
||||
|
||||
|
||||
# =========================
|
||||
# Temporal Normalization
|
||||
# =========================
|
||||
class TNorm(nn.Module):
|
||||
def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1):
|
||||
def __init__(self, num_nodes, channels, momentum=0.1, eps=1e-5):
|
||||
super().__init__()
|
||||
self.track_running_stats = track_running_stats
|
||||
self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
|
||||
self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1))
|
||||
self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1))
|
||||
self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1))
|
||||
self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1))
|
||||
self.register_buffer("running_mean", torch.zeros(1, channels, num_nodes, 1))
|
||||
self.register_buffer("running_var", torch.ones(1, channels, num_nodes, 1))
|
||||
self.momentum = momentum
|
||||
self.eps = eps
|
||||
|
||||
def forward(self, x):
|
||||
if self.track_running_stats:
|
||||
mean = x.mean((0, 3), keepdims=True)
|
||||
var = x.var((0, 3), keepdims=True, unbiased=False)
|
||||
if self.training:
|
||||
n = x.shape[3] * x.shape[0]
|
||||
with torch.no_grad():
|
||||
self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean
|
||||
self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var
|
||||
else:
|
||||
mean = self.running_mean
|
||||
var = self.running_var
|
||||
if self.training:
|
||||
mean = x.mean(dim=(0, 3), keepdim=True)
|
||||
var = x.var(dim=(0, 3), keepdim=True, unbiased=False)
|
||||
# in-place update (VERY IMPORTANT)
|
||||
self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean)
|
||||
self.running_var.mul_(1 - self.momentum).add_(self.momentum * var)
|
||||
else:
|
||||
mean = x.mean(3, keepdims=True)
|
||||
var = x.var(3, keepdims=True, unbiased=True)
|
||||
x_norm = (x - mean) / (var + 1e-5) ** 0.5
|
||||
return x_norm * self.gamma + self.beta
|
||||
mean = self.running_mean
|
||||
var = self.running_var
|
||||
|
||||
class stnorm(nn.Module):
|
||||
x = (x - mean) / torch.sqrt(var + self.eps)
|
||||
return x * self.gamma + self.beta
|
||||
|
||||
|
||||
# =========================
|
||||
# STNorm WaveNet
|
||||
# =========================
|
||||
class STNormNet(nn.Module):
|
||||
def __init__(self, args):
|
||||
super().__init__()
|
||||
self.dropout = args["dropout"]
|
||||
self.blocks = args["blocks"]
|
||||
self.layers = args["layers"]
|
||||
self.snorm_bool = args["snorm_bool"]
|
||||
self.tnorm_bool = args["tnorm_bool"]
|
||||
self.dropout = args["dropout"]
|
||||
self.num_nodes = args["num_nodes"]
|
||||
in_dim = args["in_dim"]
|
||||
out_dim = args["out_dim"]
|
||||
channels = args["channels"]
|
||||
kernel_size = args["kernel_size"]
|
||||
|
||||
# 初始化卷积层
|
||||
self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1))
|
||||
|
||||
# 初始化模块列表
|
||||
self.in_dim = args["in_dim"]
|
||||
self.out_dim = args["out_dim"]
|
||||
self.channels = args["channels"]
|
||||
self.kernel_size = args["kernel_size"]
|
||||
|
||||
self.use_snorm = args["snorm_bool"]
|
||||
self.use_tnorm = args["tnorm_bool"]
|
||||
|
||||
self.start_conv = nn.Conv2d(self.in_dim, self.channels, kernel_size=(1, 1))
|
||||
|
||||
self.filter_convs = nn.ModuleList()
|
||||
self.gate_convs = nn.ModuleList()
|
||||
self.residual_convs = nn.ModuleList()
|
||||
self.skip_convs = nn.ModuleList()
|
||||
self.sn = nn.ModuleList() if self.snorm_bool else None
|
||||
self.tn = nn.ModuleList() if self.tnorm_bool else None
|
||||
|
||||
# 计算感受野
|
||||
self.snorms = nn.ModuleList()
|
||||
self.tnorms = nn.ModuleList()
|
||||
|
||||
self.receptive_field = 1
|
||||
additional_scope = kernel_size - 1
|
||||
|
||||
# 构建网络层
|
||||
for b in range(self.blocks):
|
||||
new_dilation = 1
|
||||
for i in range(self.layers):
|
||||
if self.tnorm_bool:
|
||||
self.tn.append(TNorm(self.num_nodes, channels))
|
||||
if self.snorm_bool:
|
||||
self.sn.append(SNorm(channels))
|
||||
|
||||
# 膨胀卷积 - 直接使用channels作为输入通道,不再拼接多个特征
|
||||
self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
|
||||
kernel_size=(1, kernel_size), dilation=new_dilation))
|
||||
self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels,
|
||||
kernel_size=(1, kernel_size), dilation=new_dilation))
|
||||
|
||||
# 残差连接和跳跃连接
|
||||
self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
|
||||
self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1)))
|
||||
|
||||
# 更新感受野
|
||||
self.receptive_field += additional_scope
|
||||
additional_scope *= 2
|
||||
new_dilation *= 2
|
||||
dilation = 1
|
||||
rf_add = self.kernel_size - 1
|
||||
for _ in range(self.layers):
|
||||
if self.use_snorm:
|
||||
self.snorms.append(SNorm(self.channels))
|
||||
if self.use_tnorm:
|
||||
self.tnorms.append(TNorm(self.num_nodes, self.channels))
|
||||
|
||||
# 输出层
|
||||
self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True)
|
||||
self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True)
|
||||
self.filter_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation))
|
||||
self.gate_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation))
|
||||
self.residual_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
|
||||
self.skip_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1)))
|
||||
|
||||
self.receptive_field += rf_add
|
||||
rf_add *= 2
|
||||
dilation *= 2
|
||||
|
||||
self.end_conv_1 = nn.Conv2d(self.channels, self.channels, (1, 1))
|
||||
self.end_conv_2 = nn.Conv2d(self.channels, self.out_dim, (1, 1))
|
||||
|
||||
def forward(self, input):
|
||||
# 输入处理:与GWN保持一致 (bs, features, n_nodes, n_timesteps)
|
||||
x = input[..., 0:1].transpose(1, 3)
|
||||
|
||||
# 处理感受野
|
||||
in_len = x.size(3)
|
||||
if in_len < self.receptive_field:
|
||||
x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0))
|
||||
|
||||
# 起始卷积
|
||||
x = self.start_conv(x)
|
||||
skip = 0
|
||||
# (B, T, N, F) -> (B, F, N, T)
|
||||
x = input[..., :self.in_dim].transpose(1, 3)
|
||||
|
||||
# WaveNet层
|
||||
# pad to receptive field
|
||||
if x.size(3) < self.receptive_field:
|
||||
x = F.pad(x, (self.receptive_field - x.size(3), 0, 0, 0))
|
||||
|
||||
x = self.start_conv(x)
|
||||
skip = None
|
||||
|
||||
norm_idx = 0
|
||||
for i in range(self.blocks * self.layers):
|
||||
residual = x
|
||||
|
||||
# 添加空间和时间归一化(直接叠加到原始特征上,而不是拼接)
|
||||
x_norm = x
|
||||
if self.tnorm_bool:
|
||||
x_norm += self.tn[i](x)
|
||||
if self.snorm_bool:
|
||||
x_norm += self.sn[i](x)
|
||||
|
||||
# 膨胀卷积
|
||||
filter = torch.tanh(self.filter_convs[i](x_norm))
|
||||
gate = torch.sigmoid(self.gate_convs[i](x_norm))
|
||||
x = filter * gate
|
||||
|
||||
# 跳跃连接
|
||||
# ---------- STNorm (safe fusion) ----------
|
||||
if self.use_tnorm:
|
||||
x = x + 0.5 * self.tnorms[norm_idx](x)
|
||||
if self.use_snorm:
|
||||
x = x + 0.5 * self.snorms[norm_idx](x)
|
||||
norm_idx += 1
|
||||
# ---------- Dilated Conv ----------
|
||||
filter_out = torch.tanh(self.filter_convs[i](x))
|
||||
gate_out = torch.sigmoid(self.gate_convs[i](x))
|
||||
x = filter_out * gate_out
|
||||
# ---------- Skip (TIME SAFE) ----------
|
||||
s = self.skip_convs[i](x)
|
||||
skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0)
|
||||
|
||||
# 残差连接
|
||||
x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):]
|
||||
if skip is None:
|
||||
skip = s
|
||||
else:
|
||||
skip = skip[..., -s.size(3) :] + s
|
||||
# ---------- Residual (TIME SAFE) ----------
|
||||
x = self.residual_convs[i](x)
|
||||
x = x + residual[..., -x.size(3) :]
|
||||
|
||||
# 输出处理
|
||||
x = F.relu(skip)
|
||||
x = F.relu(self.end_conv_1(x))
|
||||
x = self.end_conv_2(x)
|
||||
return x
|
||||
x = self.end_conv_2(x) # [B, 1, N, T]
|
||||
T_out = x.size(3)
|
||||
T_target = input.size(1)
|
||||
|
||||
if T_out < T_target:
|
||||
x = F.pad(x, (T_target - T_out, 0, 0, 0)) # left pad
|
||||
|
||||
x = x.transpose(1, 3)
|
||||
return x
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@
|
|||
{
|
||||
"name": "STNorm",
|
||||
"module": "model.STNorm.STNorm",
|
||||
"entry": "stnorm"
|
||||
"entry": "STNormNet"
|
||||
}
|
||||
]
|
||||
10
train.py
10
train.py
|
|
@ -11,7 +11,7 @@ def read_config(config_path):
|
|||
config = yaml.safe_load(file)
|
||||
|
||||
# 全局配置
|
||||
device = "cuda:0" # 指定设备为cuda:0
|
||||
device = "cpu" # 指定设备为cuda:0
|
||||
seed = 2023 # 随机种子
|
||||
epochs = 1 # 训练轮数
|
||||
|
||||
|
|
@ -90,9 +90,9 @@ def main(model, data, debug=False):
|
|||
if __name__ == "__main__":
|
||||
# 调试用
|
||||
# model_list = ["iTransformer", "PatchTST", "HI"]
|
||||
model_list = ["FPT"]
|
||||
model_list = ["STNorm"]
|
||||
# model_list = ["PatchTST"]
|
||||
dataset_list = ["METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"]
|
||||
# dataset_list = ["AirQuality"]
|
||||
# dataset_list = ["METR-LA"]
|
||||
main(model_list, dataset_list, debug = False)
|
||||
dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
|
||||
# dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
|
||||
main(model_list, dataset_list, debug = True)
|
||||
Loading…
Reference in New Issue