From 44ffe94c95edd61ed3b7f0d0333a87ce445468cf Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 10 Dec 2025 11:01:00 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0iTransformer,=20HI=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E3=80=82=E6=9B=B4=E6=96=B0TS=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E8=BD=BD=E5=85=A5=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/HI/AirQuality.yaml | 6 +++--- config/HI/BJTaxi-Inflow.yaml | 6 +++--- config/HI/BJTaxi-Outflow.yaml | 6 +++--- config/HI/METR-LA.yaml | 6 +++--- config/HI/NYCBike-Inflow.yaml | 6 +++--- config/HI/NYCBike-Outflow.yaml | 6 +++--- config/HI/PEMS-BAY.yaml | 6 +++--- config/HI/SolarEnergy.yaml | 6 +++--- config/iTransformer/AirQuality.yaml | 6 +++--- config/iTransformer/BJTaxi-Inflow.yaml | 6 +++--- config/iTransformer/BJTaxi-Outflow.yaml | 6 +++--- config/iTransformer/METR-LA.yaml | 6 +++--- config/iTransformer/NYCBike-Inflow.yaml | 6 +++--- config/iTransformer/NYCBike-Outflow.yaml | 6 +++--- config/iTransformer/PEMS-BAY.yaml | 6 +++--- config/iTransformer/SolarEnergy.yaml | 6 +++--- dataloader/TSloader.py | 20 +++++++++++++++++--- 17 files changed, 65 insertions(+), 51 deletions(-) diff --git a/config/HI/AirQuality.yaml b/config/HI/AirQuality.yaml index 07300c4..147da8a 100644 --- a/config/HI/AirQuality.yaml +++ b/config/HI/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 35 + output_dim: 6 optimizer: null plot: false real_value: true diff --git a/config/HI/BJTaxi-Inflow.yaml b/config/HI/BJTaxi-Inflow.yaml index d752667..d3b39ea 100644 --- a/config/HI/BJTaxi-Inflow.yaml +++ b/config/HI/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/BJTaxi-Outflow.yaml b/config/HI/BJTaxi-Outflow.yaml index 271fbc7..96f4253 100644 --- a/config/HI/BJTaxi-Outflow.yaml +++ b/config/HI/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/METR-LA.yaml b/config/HI/METR-LA.yaml index 0826302..203db0d 100644 --- a/config/HI/METR-LA.yaml +++ b/config/HI/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 207 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/NYCBike-Inflow.yaml b/config/HI/NYCBike-Inflow.yaml index be217a9..a24a481 100644 --- a/config/HI/NYCBike-Inflow.yaml +++ b/config/HI/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/NYCBike-Outflow.yaml b/config/HI/NYCBike-Outflow.yaml index 0f93fe5..87d6156 100644 --- a/config/HI/NYCBike-Outflow.yaml +++ b/config/HI/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/PEMS-BAY.yaml b/config/HI/PEMS-BAY.yaml index 832f455..e012772 100644 --- a/config/HI/PEMS-BAY.yaml +++ b/config/HI/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 325 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/HI/SolarEnergy.yaml b/config/HI/SolarEnergy.yaml index 8f55fac..1558d3a 100644 --- a/config/HI/SolarEnergy.yaml +++ b/config/HI/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 512 column_wise: false days_per_week: 7 horizon: 24 @@ -25,7 +25,7 @@ model: train: - batch_size: 16 + batch_size: 512 debug: false early_stop: true early_stop_patience: 15 @@ -40,7 +40,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 optimizer: null plot: false real_value: true diff --git a/config/iTransformer/AirQuality.yaml b/config/iTransformer/AirQuality.yaml index 74bf69d..23eba27 100644 --- a/config/iTransformer/AirQuality.yaml +++ b/config/iTransformer/AirQuality.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 35 + output_dim: 6 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Inflow.yaml b/config/iTransformer/BJTaxi-Inflow.yaml index 8a0e7c9..dfc2df2 100644 --- a/config/iTransformer/BJTaxi-Inflow.yaml +++ b/config/iTransformer/BJTaxi-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/BJTaxi-Outflow.yaml b/config/iTransformer/BJTaxi-Outflow.yaml index ea4af50..d14bed5 100644 --- a/config/iTransformer/BJTaxi-Outflow.yaml +++ b/config/iTransformer/BJTaxi-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 2048 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 2048 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 1024 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/METR-LA.yaml b/config/iTransformer/METR-LA.yaml index 3d02d8b..20c4068 100644 --- a/config/iTransformer/METR-LA.yaml +++ b/config/iTransformer/METR-LA.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 207 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/NYCBike-Inflow.yaml b/config/iTransformer/NYCBike-Inflow.yaml index 598ca1e..8afa656 100644 --- a/config/iTransformer/NYCBike-Inflow.yaml +++ b/config/iTransformer/NYCBike-Inflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/NYCBike-Outflow.yaml b/config/iTransformer/NYCBike-Outflow.yaml index b6a8994..7abba88 100644 --- a/config/iTransformer/NYCBike-Outflow.yaml +++ b/config/iTransformer/NYCBike-Outflow.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 32 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 128 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/PEMS-BAY.yaml b/config/iTransformer/PEMS-BAY.yaml index 5140b73..17f2fd4 100644 --- a/config/iTransformer/PEMS-BAY.yaml +++ b/config/iTransformer/PEMS-BAY.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 325 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/config/iTransformer/SolarEnergy.yaml b/config/iTransformer/SolarEnergy.yaml index bab0108..cce005a 100644 --- a/config/iTransformer/SolarEnergy.yaml +++ b/config/iTransformer/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 16 + batch_size: 256 column_wise: false days_per_week: 7 horizon: 24 @@ -31,7 +31,7 @@ model: train: - batch_size: 16 + batch_size: 256 debug: false early_stop: true early_stop_patience: 15 @@ -46,7 +46,7 @@ train: mae_thresh: None mape_thresh: 0.001 max_grad_norm: 5 - output_dim: 137 + output_dim: 1 plot: false real_value: true weight_decay: 0 \ No newline at end of file diff --git a/dataloader/TSloader.py b/dataloader/TSloader.py index 33b5a17..66ef45c 100755 --- a/dataloader/TSloader.py +++ b/dataloader/TSloader.py @@ -7,16 +7,16 @@ import torch def get_dataloader(args, normalizer="std", single=True): data = load_st_dataset(args) - data = data[..., 0:1] + # data = data[..., 0:1] args = args["data"] L, N, F = data.shape - data = data.reshape(L, N*F) # [L, N*F] + # data = data.reshape(L, N*F) # [L, N*F] # Generate sliding windows for main data and add time features x, y = _prepare_data_with_windows(data, args, single) - # Split data + # Split data [b,t,n,c] split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) @@ -25,6 +25,10 @@ def get_dataloader(args, normalizer="std", single=True): scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) _apply_existing_scaler(y_train, y_val, y_test, scaler, args) + # reshape [b,t,n,c] -> [b*n, t, c] + x_train, x_val, x_test, y_train, y_val, y_test = \ + _reshape_tensor(x_train, x_val, x_test, y_train, y_val, y_test) + # Create dataloaders return ( _create_dataloader(x_train, y_train, args["batch_size"], True, False), @@ -33,6 +37,16 @@ def get_dataloader(args, normalizer="std", single=True): scaler ) +def _reshape_tensor(*tensors): + """Reshape tensors from [b, t, n, c] -> [b*n, t, c].""" + reshaped = [] + for x in tensors: + # x 是 ndarray:shape (b, t, n, c) + b, t, n, c = x.shape + x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c) + reshaped.append(x_new) + return reshaped + def _prepare_data_with_windows(data, args, single): # Generate sliding windows for main data x = add_window_x(data, args["lag"], args["horizon"], single)