更新iTransformer, HI配置。更新TS数据集载入方式
This commit is contained in:
parent
560d24e5a8
commit
44ffe94c95
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 35
|
output_dim: 6
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 2048
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 2048
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 1024
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 2048
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 2048
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 1024
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 207
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 128
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 128
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 325
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -25,7 +25,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 512
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -40,7 +40,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 137
|
output_dim: 1
|
||||||
optimizer: null
|
optimizer: null
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 35
|
output_dim: 6
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 2048
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 2048
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 1024
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 2048
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 2048
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 1024
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 207
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 128
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 128
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 325
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
|
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 16
|
batch_size: 256
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
@ -46,7 +46,7 @@ train:
|
||||||
mae_thresh: None
|
mae_thresh: None
|
||||||
mape_thresh: 0.001
|
mape_thresh: 0.001
|
||||||
max_grad_norm: 5
|
max_grad_norm: 5
|
||||||
output_dim: 137
|
output_dim: 1
|
||||||
plot: false
|
plot: false
|
||||||
real_value: true
|
real_value: true
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|
@ -7,16 +7,16 @@ import torch
|
||||||
|
|
||||||
def get_dataloader(args, normalizer="std", single=True):
|
def get_dataloader(args, normalizer="std", single=True):
|
||||||
data = load_st_dataset(args)
|
data = load_st_dataset(args)
|
||||||
data = data[..., 0:1]
|
# data = data[..., 0:1]
|
||||||
|
|
||||||
args = args["data"]
|
args = args["data"]
|
||||||
L, N, F = data.shape
|
L, N, F = data.shape
|
||||||
data = data.reshape(L, N*F) # [L, N*F]
|
# data = data.reshape(L, N*F) # [L, N*F]
|
||||||
|
|
||||||
# Generate sliding windows for main data and add time features
|
# Generate sliding windows for main data and add time features
|
||||||
x, y = _prepare_data_with_windows(data, args, single)
|
x, y = _prepare_data_with_windows(data, args, single)
|
||||||
|
|
||||||
# Split data
|
# Split data [b,t,n,c]
|
||||||
split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio
|
split_fn = split_data_by_days if args["test_ratio"] > 1 else split_data_by_ratio
|
||||||
x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"])
|
x_train, x_val, x_test = split_fn(x, args["val_ratio"], args["test_ratio"])
|
||||||
y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"])
|
y_train, y_val, y_test = split_fn(y, args["val_ratio"], args["test_ratio"])
|
||||||
|
|
@ -25,6 +25,10 @@ def get_dataloader(args, normalizer="std", single=True):
|
||||||
scaler = _normalize_data(x_train, x_val, x_test, args, normalizer)
|
scaler = _normalize_data(x_train, x_val, x_test, args, normalizer)
|
||||||
_apply_existing_scaler(y_train, y_val, y_test, scaler, args)
|
_apply_existing_scaler(y_train, y_val, y_test, scaler, args)
|
||||||
|
|
||||||
|
# reshape [b,t,n,c] -> [b*n, t, c]
|
||||||
|
x_train, x_val, x_test, y_train, y_val, y_test = \
|
||||||
|
_reshape_tensor(x_train, x_val, x_test, y_train, y_val, y_test)
|
||||||
|
|
||||||
# Create dataloaders
|
# Create dataloaders
|
||||||
return (
|
return (
|
||||||
_create_dataloader(x_train, y_train, args["batch_size"], True, False),
|
_create_dataloader(x_train, y_train, args["batch_size"], True, False),
|
||||||
|
|
@ -33,6 +37,16 @@ def get_dataloader(args, normalizer="std", single=True):
|
||||||
scaler
|
scaler
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _reshape_tensor(*tensors):
|
||||||
|
"""Reshape tensors from [b, t, n, c] -> [b*n, t, c]."""
|
||||||
|
reshaped = []
|
||||||
|
for x in tensors:
|
||||||
|
# x 是 ndarray:shape (b, t, n, c)
|
||||||
|
b, t, n, c = x.shape
|
||||||
|
x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c)
|
||||||
|
reshaped.append(x_new)
|
||||||
|
return reshaped
|
||||||
|
|
||||||
def _prepare_data_with_windows(data, args, single):
|
def _prepare_data_with_windows(data, args, single):
|
||||||
# Generate sliding windows for main data
|
# Generate sliding windows for main data
|
||||||
x = add_window_x(data, args["lag"], args["horizon"], single)
|
x = add_window_x(data, args["lag"], args["horizon"], single)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue