diff --git a/config/PatchTST/BJTaxi-InFlow.yaml b/config/PatchTST/BJTaxi-InFlow.yaml index a4e0308..95ad0b1 100644 --- a/config/PatchTST/BJTaxi-InFlow.yaml +++ b/config/PatchTST/BJTaxi-InFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/PatchTST/BJTaxi-OutFlow.yaml b/config/PatchTST/BJTaxi-OutFlow.yaml index 68c8476..f416372 100644 --- a/config/PatchTST/BJTaxi-OutFlow.yaml +++ b/config/PatchTST/BJTaxi-OutFlow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -34,7 +34,7 @@ model: train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/config/STNorm/AirQuality.yaml b/config/STNorm/AirQuality.yaml new file mode 100644 index 0000000..9846895 --- /dev/null +++ b/config/STNorm/AirQuality.yaml @@ -0,0 +1,64 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 35 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 6 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 6 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/BJTaxi-InFlow.yaml b/config/STNorm/BJTaxi-InFlow.yaml new file mode 100644 index 0000000..09e453a --- /dev/null +++ b/config/STNorm/BJTaxi-InFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/BJTaxi-OutFlow.yaml b/config/STNorm/BJTaxi-OutFlow.yaml new file mode 100644 index 0000000..1b62a4e --- /dev/null +++ b/config/STNorm/BJTaxi-OutFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/METR-LA.yaml b/config/STNorm/METR-LA.yaml new file mode 100644 index 0000000..6f118f2 --- /dev/null +++ b/config/STNorm/METR-LA.yaml @@ -0,0 +1,52 @@ +basic: + dataset: METR-LA + device: cuda:0 + mode: train + model: STNorm + seed: 2023 + +data: + batch_size: 16 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + dropout: 0.2 + blocks: 2 + layers: 2 + snorm_bool: True + tnorm_bool: True + num_nodes: 207 + in_dim: 1 + out_dim: 24 + channels: 32 + kernel_size: 2 + +train: + batch_size: 16 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/NYCBike-InFlow.yaml b/config/STNorm/NYCBike-InFlow.yaml new file mode 100644 index 0000000..95ae41b --- /dev/null +++ b/config/STNorm/NYCBike-InFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/NYCBike-OutFlow.yaml b/config/STNorm/NYCBike-OutFlow.yaml new file mode 100644 index 0000000..b1646ea --- /dev/null +++ b/config/STNorm/NYCBike-OutFlow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/PEMS-BAY.yaml b/config/STNorm/PEMS-BAY.yaml new file mode 100644 index 0000000..7f28aca --- /dev/null +++ b/config/STNorm/PEMS-BAY.yaml @@ -0,0 +1,64 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 325 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/STNorm/SolarEnergy.yaml b/config/STNorm/SolarEnergy.yaml new file mode 100644 index 0000000..57e17c8 --- /dev/null +++ b/config/STNorm/SolarEnergy.yaml @@ -0,0 +1,64 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 137 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/model/STNorm/STNorm.py b/model/STNorm/STNorm.py new file mode 100644 index 0000000..71e7118 --- /dev/null +++ b/model/STNorm/STNorm.py @@ -0,0 +1,140 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SNorm(nn.Module): + def __init__(self, channels): + super().__init__() + self.beta = nn.Parameter(torch.zeros(channels)) + self.gamma = nn.Parameter(torch.ones(channels)) + + def forward(self, x): + x_norm = (x - x.mean(2, keepdims=True)) / (x.var(2, keepdims=True, unbiased=True) + 1e-5) ** 0.5 + return x_norm * self.gamma.view(1, -1, 1, 1) + self.beta.view(1, -1, 1, 1) + +class TNorm(nn.Module): + def __init__(self, num_nodes, channels, track_running_stats=True, momentum=0.1): + super().__init__() + self.track_running_stats = track_running_stats + self.beta = nn.Parameter(torch.zeros(1, channels, num_nodes, 1)) + self.gamma = nn.Parameter(torch.ones(1, channels, num_nodes, 1)) + self.register_buffer('running_mean', torch.zeros(1, channels, num_nodes, 1)) + self.register_buffer('running_var', torch.ones(1, channels, num_nodes, 1)) + self.momentum = momentum + + def forward(self, x): + if self.track_running_stats: + mean = x.mean((0, 3), keepdims=True) + var = x.var((0, 3), keepdims=True, unbiased=False) + if self.training: + n = x.shape[3] * x.shape[0] + with torch.no_grad(): + self.running_mean = self.momentum * mean + (1 - self.momentum) * self.running_mean + self.running_var = self.momentum * var * n / (n - 1) + (1 - self.momentum) * self.running_var + else: + mean = self.running_mean + var = self.running_var + else: + mean = x.mean(3, keepdims=True) + var = x.var(3, keepdims=True, unbiased=True) + x_norm = (x - mean) / (var + 1e-5) ** 0.5 + return x_norm * self.gamma + self.beta + +class stnorm(nn.Module): + def __init__(self, args): + super().__init__() + self.dropout = args["dropout"] + self.blocks = args["blocks"] + self.layers = args["layers"] + self.snorm_bool = args["snorm_bool"] + self.tnorm_bool = args["tnorm_bool"] + self.num_nodes = args["num_nodes"] + in_dim = args["in_dim"] + out_dim = args["out_dim"] + channels = args["channels"] + kernel_size = args["kernel_size"] + + # 初始化卷积层 + self.start_conv = nn.Conv2d(in_channels=in_dim, out_channels=channels, kernel_size=(1, 1)) + + # 初始化模块列表 + self.filter_convs = nn.ModuleList() + self.gate_convs = nn.ModuleList() + self.residual_convs = nn.ModuleList() + self.skip_convs = nn.ModuleList() + self.sn = nn.ModuleList() if self.snorm_bool else None + self.tn = nn.ModuleList() if self.tnorm_bool else None + + # 计算感受野 + self.receptive_field = 1 + additional_scope = kernel_size - 1 + + # 构建网络层 + for b in range(self.blocks): + new_dilation = 1 + for i in range(self.layers): + if self.tnorm_bool: + self.tn.append(TNorm(self.num_nodes, channels)) + if self.snorm_bool: + self.sn.append(SNorm(channels)) + + # 膨胀卷积 - 直接使用channels作为输入通道,不再拼接多个特征 + self.filter_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, + kernel_size=(1, kernel_size), dilation=new_dilation)) + self.gate_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, + kernel_size=(1, kernel_size), dilation=new_dilation)) + + # 残差连接和跳跃连接 + self.residual_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + self.skip_convs.append(nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1))) + + # 更新感受野 + self.receptive_field += additional_scope + additional_scope *= 2 + new_dilation *= 2 + + # 输出层 + self.end_conv_1 = nn.Conv2d(in_channels=channels, out_channels=channels, kernel_size=(1, 1), bias=True) + self.end_conv_2 = nn.Conv2d(in_channels=channels, out_channels=out_dim, kernel_size=(1, 1), bias=True) + + def forward(self, input): + # 输入处理:与GWN保持一致 (bs, features, n_nodes, n_timesteps) + x = input[..., 0:1].transpose(1, 3) + + # 处理感受野 + in_len = x.size(3) + if in_len < self.receptive_field: + x = nn.functional.pad(x, (self.receptive_field - in_len, 0, 0, 0)) + + # 起始卷积 + x = self.start_conv(x) + skip = 0 + + # WaveNet层 + for i in range(self.blocks * self.layers): + residual = x + + # 添加空间和时间归一化(直接叠加到原始特征上,而不是拼接) + x_norm = x + if self.tnorm_bool: + x_norm += self.tn[i](x) + if self.snorm_bool: + x_norm += self.sn[i](x) + + # 膨胀卷积 + filter = torch.tanh(self.filter_convs[i](x_norm)) + gate = torch.sigmoid(self.gate_convs[i](x_norm)) + x = filter * gate + + # 跳跃连接 + s = self.skip_convs[i](x) + skip = s + (skip[:, :, :, -s.size(3):] if isinstance(skip, torch.Tensor) else 0) + + # 残差连接 + x = self.residual_convs[i](x) + residual[:, :, :, -x.size(3):] + + # 输出处理 + x = F.relu(skip) + x = F.relu(self.end_conv_1(x)) + x = self.end_conv_2(x) + return x \ No newline at end of file diff --git a/model/STNorm/model_config.json b/model/STNorm/model_config.json new file mode 100644 index 0000000..62ea48c --- /dev/null +++ b/model/STNorm/model_config.json @@ -0,0 +1,7 @@ +[ + { + "name": "STNorm", + "module": "model.STNorm.STNorm", + "entry": "stnorm" + } +] \ No newline at end of file diff --git a/train.py b/train.py index e9db08b..40b3bcb 100644 --- a/train.py +++ b/train.py @@ -90,9 +90,9 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["ASTRA_v3"] - # model_list = ["MTGNN"] - # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] + # model_list = ["ASTRA_v3"] + model_list = ["PatchTST"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"] - dataset_list = ["AirQuality", "SolarEnergy", "METR-LA", "NYCBike-InFlow", "NYCBike-OutFlow"] - main(model_list, dataset_list, debug = True) \ No newline at end of file + # dataset_list = ["METR-LA"] + main(model_list, dataset_list, debug = False) \ No newline at end of file diff --git a/trainer/TSTrainer.py b/trainer/TSTrainer.py index 3ddf361..11cd431 100755 --- a/trainer/TSTrainer.py +++ b/trainer/TSTrainer.py @@ -102,7 +102,9 @@ class Trainer: for data, target in self.test_loader: data, target = data.to(self.device), target.to(self.device) label = target[..., :self.args["output_dim"]] - y_pred.append(self.model(data).cpu()) + x, shp = self.pack(data) + out = self.unpack(self.model(x), shp) + y_pred.append(out.cpu()) y_true.append(label.cpu()) d_pred, d_true = self.inv(torch.cat(y_pred)), self.inv(torch.cat(y_true)) # 反归一化