fix Informer config bug

This commit is contained in:
czzhangheng 2025-12-21 17:45:51 +08:00
parent c121912f03
commit ce6959a99d
3 changed files with 11 additions and 11 deletions

View File

@ -22,7 +22,7 @@ data:
model:
activation: gelu
seq_len: 24
label_len: 12
label_len: 24
pred_len: 24
d_model: 128
d_ff: 2048
@ -59,7 +59,7 @@ train:
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 6
output_dim: 1
plot: false
pred_len: 24
real_value: true

View File

@ -6,19 +6,18 @@ basic:
seed: 2023
data:
batch_size: 16
batch_size: 64
column_wise: false
days_per_week: 7
horizon: 24
input_dim: 6
input_dim: 1
lag: 24
label_len: 24
normalizer: std
num_nodes: 137
steps_per_day: 24
test_ratio: 0.2
val_ratio: 0.2
model:
activation: gelu
seq_len: 24
@ -43,7 +42,7 @@ model:
train:
batch_size: 16
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15

View File

@ -12,7 +12,7 @@ def read_config(config_path):
config = yaml.safe_load(file)
# 全局配置
device = "cuda:0" # 指定设备为cuda:0
device = "cpu" # 指定设备为cuda:0
seed = 2023 # 随机种子
epochs = 1 # 训练轮数
@ -102,10 +102,11 @@ def main(model_list, data, debug=False):
if __name__ == "__main__":
# 调试用
# model_list = ["iTransformer", "PatchTST", "HI"]
model_list = ["Informer"]
model_list = ["iTransformer", "Informer"]
# model_list = ["PatchTST"]
# dataset_list = ["AirQuality"]
dataset_list = ["SolarEnergy"]
# dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
# dataset_list = ["METR-LA"]
# dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
# dataset_list = ["BJTaxi-OutFlow"]
main(model_list, dataset_list, debug=True)