From 9d3293cef7baf23be204144311b90f2b5e426dae Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 20 Dec 2025 15:45:13 +0800 Subject: [PATCH] impl STNorm --- config/STNorm/AirQuality.yaml | 34 ++--- config/STNorm/BJTaxi-InFlow.yaml | 34 ++--- config/STNorm/BJTaxi-OutFlow.yaml | 34 ++--- config/STNorm/METR-LA.yaml | 2 +- config/STNorm/NYCBike-InFlow.yaml | 34 ++--- config/STNorm/NYCBike-OutFlow.yaml | 34 ++--- config/STNorm/PEMS-BAY.yaml | 34 ++--- config/STNorm/SolarEnergy.yaml | 34 ++--- model/STNorm/STNorm.py | 206 +++++++++++++++-------------- model/STNorm/model_config.json | 2 +- train.py | 10 +- 11 files changed, 191 insertions(+), 267 deletions(-) diff --git a/config/STNorm/AirQuality.yaml b/config/STNorm/AirQuality.yaml index 9846895..384633d 100644 --- a/config/STNorm/AirQuality.yaml +++ b/config/STNorm/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 35 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 6 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 6 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 35 + in_dim: 6 + out_dim: 6 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/BJTaxi-InFlow.yaml b/config/STNorm/BJTaxi-InFlow.yaml index 09e453a..13130be 100644 --- a/config/STNorm/BJTaxi-InFlow.yaml +++ b/config/STNorm/BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 1024 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 1024 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/BJTaxi-OutFlow.yaml b/config/STNorm/BJTaxi-OutFlow.yaml index 1b62a4e..fec550a 100644 --- a/config/STNorm/BJTaxi-OutFlow.yaml +++ b/config/STNorm/BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 1024 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 1024 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/METR-LA.yaml b/config/STNorm/METR-LA.yaml index 6f118f2..f48c978 100644 --- a/config/STNorm/METR-LA.yaml +++ b/config/STNorm/METR-LA.yaml @@ -26,7 +26,7 @@ model: tnorm_bool: True num_nodes: 207 in_dim: 1 - out_dim: 24 + out_dim: 1 channels: 32 kernel_size: 2 diff --git a/config/STNorm/NYCBike-InFlow.yaml b/config/STNorm/NYCBike-InFlow.yaml index 95ae41b..57ad401 100644 --- a/config/STNorm/NYCBike-InFlow.yaml +++ b/config/STNorm/NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 128 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 128 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/NYCBike-OutFlow.yaml b/config/STNorm/NYCBike-OutFlow.yaml index b1646ea..4f32f0a 100644 --- a/config/STNorm/NYCBike-OutFlow.yaml +++ b/config/STNorm/NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 128 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 128 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/PEMS-BAY.yaml b/config/STNorm/PEMS-BAY.yaml index 7f28aca..20f4b5d 100644 --- a/config/STNorm/PEMS-BAY.yaml +++ b/config/STNorm/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 325 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 325 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/config/STNorm/SolarEnergy.yaml b/config/STNorm/SolarEnergy.yaml index 57e17c8..d1be59c 100644 --- a/config/STNorm/SolarEnergy.yaml +++ b/config/STNorm/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: MTGNN + model: STNorm seed: 2023 data: @@ -19,28 +19,16 @@ data: val_ratio: 0.2 model: - gcn_true: True # 是否使用图卷积网络 (bool) - buildA_true: True # 是否动态构建邻接矩阵 (bool) - subgraph_size: 20 # 子图大小 (int) - num_nodes: 137 # 节点数量 (int) - node_dim: 40 # 节点嵌入维度 (int) - dilation_exponential: 1 # 膨胀卷积指数 (int) - conv_channels: 32 # 卷积通道数 (int) - residual_channels: 32 # 残差通道数 (int) - skip_channels: 64 # 跳跃连接通道数 (int) - end_channels: 128 # 输出层通道数 (int) - seq_len: 24 # 输入序列长度 (int) - in_dim: 1 # 输入特征维度 (int) - out_len: 24 # 输出序列长度 (int) - out_dim: 1 # 输出预测维度 (int) - layers: 3 # 模型层数 (int) - propalpha: 0.05 # 图传播参数alpha (float) - tanhalpha: 3 # tanh激活参数alpha (float) - layer_norm_affline: True # 层归一化是否使用affine变换 (bool) - gcn_depth: 2 # 图卷积深度 (int) - dropout: 0.3 # dropout率 (float) - predefined_A: null # 预定义邻接矩阵 (optional, None) - static_feat: null # 静态特征 (optional, None) + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 137 + in_dim: 1 + out_dim: 1 + channels: 32 + kernel_size: 2 train: batch_size: 64 diff --git a/model/STNorm/STNorm.py b/model/STNorm/STNorm.py index 71e7118..11a72e4 100644 --- a/model/STNorm/STNorm.py +++ b/model/STNorm/STNorm.py @@ -2,139 +2,147 @@ import torch import torch.nn as nn import torch.nn.functional as F + +# ========================= +# Spatial Normalization +# ========================= class SNorm(nn.Module): - def __init__(self, channels): + def __init__(self, channels, eps=1e-5): super().__init__() - self.beta = nn.Parameter(torch.zeros(channels)) - self.gamma = nn.Parameter(torch.ones(channels)) + self.gamma = nn.Parameter(torch.ones(1, channels, 1, 1)) + self.beta = nn.Parameter(torch.zeros(1, channels, 1, 1)) + self.eps = eps def forward(self, x): - x_norm = (x - x.mean(2, keepdims=True)) / (x.var(2, keepdims=True, unbiased=True) + 1e-5) ** 0.5 - return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1) + # normalize over node dimension + mean = x.mean(dim=2, keepdim=True) + var = x.var(dim=2, keepdim=True, unbiased=False) + x = (x - mean) / torch.sqrt(var + self.eps) + return x * self.gamma + self.beta + +# ========================= +# Temporal Normalization +# ========================= class TNorm(nn.Module): - def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1): + def __init__(self, num_nodes, channels, momentum=0.1, eps=1e-5): super().__init__() - self.track_running_stats = track_running_stats - self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1)) - self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1)) - self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1)) + self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer("running_mean", torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer("running_var", torch.ones(1, channels, num_nodes, 1)) self.momentum = momentum + self.eps = eps def forward(self, x): - if self.track_running_stats: - mean = x.mean((0, 3), keepdims=True) - var = x.var((0, 3), keepdims=True, unbiased=False) - if self.training: - n = x.shape[3] * x.shape[0] - with torch.no_grad(): - self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean - self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var - else: - mean = self.running_mean - var = self.running_var + if self.training: + mean = x.mean(dim=(0, 3), keepdim=True) + var = x.var(dim=(0, 3), keepdim=True, unbiased=False) + # in-place update (VERY IMPORTANT) + self.running_mean.mul_(1 - self.momentum).add_(self.momentum * mean) + self.running_var.mul_(1 - self.momentum).add_(self.momentum * var) else: - mean = x.mean(3, keepdims=True) - var = x.var(3, keepdims=True, unbiased=True) - x_norm = (x - mean) / (var + 1e-5) ** 0.5 - return x_norm * self.gamma + self.beta + mean = self.running_mean + var = self.running_var -class stnorm(nn.Module): + x = (x - mean) / torch.sqrt(var + self.eps) + return x * self.gamma + self.beta + + +# ========================= +# STNorm WaveNet +# ========================= +class STNormNet(nn.Module): def __init__(self, args): super().__init__() - self.dropout = args["dropout"] self.blocks = args["blocks"] self.layers = args["layers"] - self.snorm_bool = args["snorm_bool"] - self.tnorm_bool = args["tnorm_bool"] + self.dropout = args["dropout"] self.num_nodes = args["num_nodes"] - in_dim = args["in_dim"] - out_dim = args["out_dim"] - channels = args["channels"] - kernel_size = args["kernel_size"] - # 初始化卷积层 - self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1)) - - # 初始化模块列表 + self.in_dim = args["in_dim"] + self.out_dim = args["out_dim"] + self.channels = args["channels"] + self.kernel_size = args["kernel_size"] + + self.use_snorm = args["snorm_bool"] + self.use_tnorm = args["tnorm_bool"] + + self.start_conv = nn.Conv2d(self.in_dim, self.channels, kernel_size=(1, 1)) + self.filter_convs = nn.ModuleList() self.gate_convs = nn.ModuleList() self.residual_convs = nn.ModuleList() self.skip_convs = nn.ModuleList() - self.sn = nn.ModuleList() if self.snorm_bool else None - self.tn = nn.ModuleList() if self.tnorm_bool else None - # 计算感受野 + self.snorms = nn.ModuleList() + self.tnorms = nn.ModuleList() + self.receptive_field = 1 - additional_scope = kernel_size - 1 - # 构建网络层 for b in range(self.blocks): - new_dilation = 1 - for i in range(self.layers): - if self.tnorm_bool: - self.tn.append(TNorm(self.num_nodes, channels)) - if self.snorm_bool: - self.sn.append(SNorm(channels)) - - # 膨胀卷积 - 直接使用channels作为输入通道,不再拼接多个特征 - self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, - kernel_size=(1, kernel_size), dilation=new_dilation)) - self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, - kernel_size=(1, kernel_size), dilation=new_dilation)) - - # 残差连接和跳跃连接 - self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) - self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) - - # 更新感受野 - self.receptive_field += additional_scope - additional_scope *= 2 - new_dilation *= 2 + dilation = 1 + rf_add = self.kernel_size - 1 + for _ in range(self.layers): + if self.use_snorm: + self.snorms.append(SNorm(self.channels)) + if self.use_tnorm: + self.tnorms.append(TNorm(self.num_nodes, self.channels)) - # 输出层 - self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True) - self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True) + self.filter_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)) + self.gate_convs.append(nn.Conv2d(self.channels, self.channels, (1, self.kernel_size), dilation=dilation)) + self.residual_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1))) + self.skip_convs.append(nn.Conv2d(self.channels, self.channels, (1, 1))) + + self.receptive_field += rf_add + rf_add *= 2 + dilation *= 2 + + self.end_conv_1 = nn.Conv2d(self.channels, self.channels, (1, 1)) + self.end_conv_2 = nn.Conv2d(self.channels, self.out_dim, (1, 1)) def forward(self, input): - # 输入处理:与GWN保持一致 (bs, features, n_nodes, n_timesteps) - x = input[..., 0:1].transpose(1, 3) - - # 处理感受野 - in_len = x.size(3) - if in_len < self.receptive_field: - x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) - - # 起始卷积 - x = self.start_conv(x) - skip = 0 + # (B, T, N, F) -> (B, F, N, T) + x = input[..., :self.in_dim].transpose(1, 3) - # WaveNet层 + # pad to receptive field + if x.size(3) < self.receptive_field: + x = F.pad(x, (self.receptive_field - x.size(3), 0, 0, 0)) + + x = self.start_conv(x) + skip = None + + norm_idx = 0 for i in range(self.blocks * self.layers): residual = x - - # 添加空间和时间归一化(直接叠加到原始特征上,而不是拼接) - x_norm = x - if self.tnorm_bool: - x_norm += self.tn[i](x) - if self.snorm_bool: - x_norm += self.sn[i](x) - - # 膨胀卷积 - filter = torch.tanh(self.filter_convs[i](x_norm)) - gate = torch.sigmoid(self.gate_convs[i](x_norm)) - x = filter * gate - - # 跳跃连接 + # ---------- STNorm (safe fusion) ---------- + if self.use_tnorm: + x = x + 0.5 * self.tnorms[norm_idx](x) + if self.use_snorm: + x = x + 0.5 * self.snorms[norm_idx](x) + norm_idx += 1 + # ---------- Dilated Conv ---------- + filter_out = torch.tanh(self.filter_convs[i](x)) + gate_out = torch.sigmoid(self.gate_convs[i](x)) + x = filter_out * gate_out + # ---------- Skip (TIME SAFE) ---------- s = self.skip_convs[i](x) - skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) - - # 残差连接 - x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):] + if skip is None: + skip = s + else: + skip = skip[..., -s.size(3) :] + s + # ---------- Residual (TIME SAFE) ---------- + x = self.residual_convs[i](x) + x = x + residual[..., -x.size(3) :] - # 输出处理 x = F.relu(skip) x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) - return x \ No newline at end of file + x = self.end_conv_2(x) # [B, 1, N, T] + T_out = x.size(3) + T_target = input.size(1) + + if T_out < T_target: + x = F.pad(x, (T_target - T_out, 0, 0, 0)) # left pad + + x = x.transpose(1, 3) + return x diff --git a/model/STNorm/model_config.json b/model/STNorm/model_config.json index 62ea48c..f860d07 100644 --- a/model/STNorm/model_config.json +++ b/model/STNorm/model_config.json @@ -2,6 +2,6 @@ { "name": "STNorm", "module": "model.STNorm.STNorm", - "entry": "stnorm" + "entry": "STNormNet" } ] \ No newline at end of file diff --git a/train.py b/train.py index 7242ac0..db9d8dd 100644 --- a/train.py +++ b/train.py @@ -11,7 +11,7 @@ def read_config(config_path): config = yaml.safe_load(file) # 全局配置 - device = "cuda:0" # 指定设备为cuda:0 + device = "cpu" # 指定设备为cuda:0 seed = 2023 # 随机种子 epochs = 1 # 训练轮数 @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["FPT"] + model_list = ["STNorm"] # model_list = ["PatchTST"] - dataset_list = ["METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - # dataset_list = ["METR-LA"] - main(model_list, dataset_list, debug = False) \ No newline at end of file + dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] + # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] + main(model_list, dataset_list, debug = True) \ No newline at end of file