fix Informer config bug

This commit is contained in:
czzhangheng 2025-12-21 17:45:51 +08:00
parent c121912f03
commit ce6959a99d
3 changed files with 11 additions and 11 deletions

View File

@ -22,7 +22,7 @@ data:
model: model:
activation: gelu activation: gelu
seq_len: 24 seq_len: 24
label_len: 12 label_len: 24
pred_len: 24 pred_len: 24
d_model: 128 d_model: 128
d_ff: 2048 d_ff: 2048
@ -59,7 +59,7 @@ train:
mae_thresh: None mae_thresh: None
mape_thresh: 0.001 mape_thresh: 0.001
max_grad_norm: 5 max_grad_norm: 5
output_dim: 6 output_dim: 1
plot: false plot: false
pred_len: 24 pred_len: 24
real_value: true real_value: true

View File

@ -6,19 +6,18 @@ basic:
seed: 2023 seed: 2023
data: data:
batch_size: 16 batch_size: 64
column_wise: false column_wise: false
days_per_week: 7 days_per_week: 7
horizon: 24 horizon: 24
input_dim: 6 input_dim: 1
lag: 24 lag: 24
label_len: 24
normalizer: std normalizer: std
num_nodes: 137 num_nodes: 137
steps_per_day: 24 steps_per_day: 24
test_ratio: 0.2 test_ratio: 0.2
val_ratio: 0.2 val_ratio: 0.2
model: model:
activation: gelu activation: gelu
seq_len: 24 seq_len: 24
@ -43,7 +42,7 @@ model:
train: train:
batch_size: 16 batch_size: 64
debug: false debug: false
early_stop: true early_stop: true
early_stop_patience: 15 early_stop_patience: 15

View File

@ -12,7 +12,7 @@ def read_config(config_path):
config = yaml.safe_load(file) config = yaml.safe_load(file)
# 全局配置 # 全局配置
device = "cuda:0" # 指定设备为cuda:0 device = "cpu" # 指定设备为cuda:0
seed = 2023 # 随机种子 seed = 2023 # 随机种子
epochs = 1 # 训练轮数 epochs = 1 # 训练轮数
@ -102,10 +102,11 @@ def main(model_list, data, debug=False):
if __name__ == "__main__": if __name__ == "__main__":
# 调试用 # 调试用
# model_list = ["iTransformer", "PatchTST", "HI"] # model_list = ["iTransformer", "PatchTST", "HI"]
model_list = ["Informer"] model_list = ["iTransformer", "Informer"]
# model_list = ["PatchTST"] # model_list = ["PatchTST"]
# dataset_list = ["AirQuality"] # dataset_list = ["AirQuality"]
dataset_list = ["SolarEnergy"]
# dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"] # dataset_list = ["BJTaxi-InFlow", "BJTaxi-OutFlow"]
dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"] # dataset_list = ["AirQuality", "PEMS-BAY", "SolarEnergy", "NYCBike-InFlow", "NYCBike-OutFlow", "METR-LA"]
# dataset_list = ["METR-LA"] # dataset_list = ["BJTaxi-OutFlow"]
main(model_list, dataset_list, debug=True) main(model_list, dataset_list, debug=True)