更新iTransformer, HI配置。更新TS数据集载入方式

This commit is contained in:
czzhangheng 2025-12-10 11:01:00 +08:00
parent 560d24e5a8
commit 44ffe94c95
17 changed files with 65 additions and 51 deletions

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 35 output_dim: 6
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 2048
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 2048
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 1024 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 2048
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 2048
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 1024 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 207 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 128 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 128 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 325 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 512
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -25,7 +25,7 @@ model:
train: train:
batch_size: 16 batch_size: 512
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -40,7 +40,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 137 output_dim: 1
optimizer: null optimizer: null
plot: false plot: false
real_value: true real_value: true

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 35 output_dim: 6
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 2048
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 2048
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 1024 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 2048
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 2048
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 1024 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 207 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 128 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 32 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 128 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 325 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -6,7 +6,7 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 256
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
@ -31,7 +31,7 @@ model:
train: train:
batch_size: 16 batch_size: 256
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15
@ -46,7 +46,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 137 output_dim: 1
plot: false plot: false
real_value: true real_value: true
weight_decay: 0 weight_decay: 0

View File

@ -7,16 +7,16 @@ import torch
def get_dataloader(args, normalizer="std", single=True): def get_dataloader(args, normalizer="std", single=True):
data = load_st_dataset(args) data = load_st_dataset(args)
data = data[..., 0:1] # data = data[..., 0:1]
args = args["data"] args = args["data"]
L, N, F = data.shape L, N, F = data.shape
data = data.reshape(L, N*F) # [L, N*F] # data = data.reshape(L, N*F) # [L, N*F]
# Generate sliding windows for main data and add time features # Generate sliding windows for main data and add time features
x, y = _prepare_data_with_windows(data, args, single) x, y = _prepare_data_with_windows(data, args, single)
# Split data # Split data [b,t,n,c]
split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio
x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"]) x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"])
y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"]) y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"])
@ -25,6 +25,10 @@ def get_dataloader(args, normalizer="std", single=True):
scaler = _normalize_data(x_train, x_val, x_test, args, normalizer) scaler = _normalize_data(x_train, x_val, x_test, args, normalizer)
_apply_existing_scaler(y_train, y_val, y_test, scaler, args) _apply_existing_scaler(y_train, y_val, y_test, scaler, args)
# reshape [b,t,n,c] -> [b*n, t, c]
x_train, x_val, x_test, y_train, y_val, y_test = \
_reshape_tensor(x_train, x_val, x_test, y_train, y_val, y_test)
# Create dataloaders # Create dataloaders
return ( return (
_create_dataloader(x_train, y_train, args["batch_size"], True, False), _create_dataloader(x_train, y_train, args["batch_size"], True, False),
@ -33,6 +37,16 @@ def get_dataloader(args, normalizer="std", single=True):
scaler scaler
) )
def _reshape_tensor(*tensors):
"""Reshape tensors from [b, t, n, c] -> [b*n, t, c]."""
reshaped = []
for x in tensors:
# x 是 ndarrayshape (b, t, n, c)
b, t, n, c = x.shape
x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c)
reshaped.append(x_new)
return reshaped
def _prepare_data_with_windows(data, args, single): def _prepare_data_with_windows(data, args, single):
# Generate sliding windows for main data # Generate sliding windows for main data
x = add_window_x(data, args["lag"], args["horizon"], single) x = add_window_x(data, args["lag"], args["horizon"], single)