bug fixes
This commit is contained in:
parent
2bfa444f8e
commit
2c54d81a67
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 16
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
stride: 7
|
stride: 7
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 32
|
batch_size: 16
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ basic:
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
batch_size: 32
|
batch_size: 16
|
||||||
column_wise: false
|
column_wise: false
|
||||||
days_per_week: 7
|
days_per_week: 7
|
||||||
horizon: 24
|
horizon: 24
|
||||||
|
|
@ -31,7 +31,7 @@ model:
|
||||||
stride: 7
|
stride: 7
|
||||||
|
|
||||||
train:
|
train:
|
||||||
batch_size: 32
|
batch_size: 16
|
||||||
debug: false
|
debug: false
|
||||||
early_stop: true
|
early_stop: true
|
||||||
early_stop_patience: 15
|
early_stop_patience: 15
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ model:
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
n_heads: 1
|
n_heads: 1
|
||||||
num_nodes: 1024
|
num_nodes: 1024
|
||||||
|
output_dim: 1
|
||||||
patch_len: 6
|
patch_len: 6
|
||||||
pred_len: 24
|
pred_len: 24
|
||||||
seq_len: 24
|
seq_len: 24
|
||||||
|
|
|
||||||
2
train.py
2
train.py
|
|
@ -110,7 +110,7 @@ def read_config(config_path):
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 调试用
|
# 调试用
|
||||||
# model_list = ["iTransformer", "PatchTST", "HI"]
|
# model_list = ["iTransformer", "PatchTST", "HI"]
|
||||||
model_list = ["STID"]
|
model_list = ["REPST"]
|
||||||
# model_list = ["PatchTST"]
|
# model_list = ["PatchTST"]
|
||||||
|
|
||||||
air = ["AirQuality"]
|
air = ["AirQuality"]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue