From 6545482a586a0ea418621757399d729297a4e36f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Mon, 18 Aug 2025 21:57:45 +0800 Subject: [PATCH] new-model,MEGA stssl --- config/MegaCRN/PEMSD3.yaml | 52 +++++++++++++++++++++++++++++++++++++ config/MegaCRN/PEMSD4.yaml | 52 +++++++++++++++++++++++++++++++++++++ config/MegaCRN/PEMSD7.yaml | 52 +++++++++++++++++++++++++++++++++++++ config/MegaCRN/PEMSD8.yaml | 52 +++++++++++++++++++++++++++++++++++++ config/ST_SSL/PEMSD3.yaml | 53 ++++++++++++++++++++++++++++++++++++++ config/ST_SSL/PEMSD4.yaml | 53 ++++++++++++++++++++++++++++++++++++++ config/ST_SSL/PEMSD7.yaml | 53 ++++++++++++++++++++++++++++++++++++++ config/ST_SSL/PEMSD8.yaml | 53 ++++++++++++++++++++++++++++++++++++++ model/model_selector.py | 4 +++ 9 files changed, 424 insertions(+) create mode 100644 config/MegaCRN/PEMSD3.yaml create mode 100644 config/MegaCRN/PEMSD4.yaml create mode 100644 config/MegaCRN/PEMSD7.yaml create mode 100644 config/MegaCRN/PEMSD8.yaml create mode 100644 config/ST_SSL/PEMSD3.yaml create mode 100644 config/ST_SSL/PEMSD4.yaml create mode 100644 config/ST_SSL/PEMSD7.yaml create mode 100644 config/ST_SSL/PEMSD8.yaml diff --git a/config/MegaCRN/PEMSD3.yaml b/config/MegaCRN/PEMSD3.yaml new file mode 100644 index 0000000..4d90c7c --- /dev/null +++ b/config/MegaCRN/PEMSD3.yaml @@ -0,0 +1,52 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + rnn_units: 64 + num_layers: 1 + cheb_k: 3 + ycov_dim: 1 + mem_num: 20 + mem_dim: 64 + cl_decay_steps: 2000 + use_curriculum_learning: True + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 200 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/MegaCRN/PEMSD4.yaml b/config/MegaCRN/PEMSD4.yaml new file mode 100644 index 0000000..cebb3be --- /dev/null +++ b/config/MegaCRN/PEMSD4.yaml @@ -0,0 +1,52 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + rnn_units: 64 + num_layers: 1 + cheb_k: 3 + ycov_dim: 1 + mem_num: 20 + mem_dim: 64 + cl_decay_steps: 2000 + use_curriculum_learning: True + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 200 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/MegaCRN/PEMSD7.yaml b/config/MegaCRN/PEMSD7.yaml new file mode 100644 index 0000000..965ef14 --- /dev/null +++ b/config/MegaCRN/PEMSD7.yaml @@ -0,0 +1,52 @@ +data: + num_nodes: 883 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + rnn_units: 64 + num_layers: 1 + cheb_k: 3 + ycov_dim: 1 + mem_num: 20 + mem_dim: 64 + cl_decay_steps: 2000 + use_curriculum_learning: True + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/MegaCRN/PEMSD8.yaml b/config/MegaCRN/PEMSD8.yaml new file mode 100644 index 0000000..3d00c33 --- /dev/null +++ b/config/MegaCRN/PEMSD8.yaml @@ -0,0 +1,52 @@ +data: + num_nodes: 170 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + rnn_units: 64 + num_layers: 1 + cheb_k: 3 + ycov_dim: 1 + mem_num: 20 + mem_dim: 64 + cl_decay_steps: 2000 + use_curriculum_learning: True + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/ST_SSL/PEMSD3.yaml b/config/ST_SSL/PEMSD3.yaml new file mode 100644 index 0000000..3c2f488 --- /dev/null +++ b/config/ST_SSL/PEMSD3.yaml @@ -0,0 +1,53 @@ +data: + num_nodes: 358 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + n_his: 12 + d_model: 64 + dropout: 0.1 + nmb_prototype: 10 + shm_temp: 0.1 + yita: 0.5 + percent: 0.1 + gso_type: sym_norm_lap + graph_conv_type: cheb_graph_conv + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 200 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/ST_SSL/PEMSD4.yaml b/config/ST_SSL/PEMSD4.yaml new file mode 100644 index 0000000..cf44a11 --- /dev/null +++ b/config/ST_SSL/PEMSD4.yaml @@ -0,0 +1,53 @@ +data: + num_nodes: 307 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + n_his: 12 + d_model: 64 + dropout: 0.1 + nmb_prototype: 10 + shm_temp: 0.1 + yita: 0.5 + percent: 0.1 + gso_type: sym_norm_lap + graph_conv_type: cheb_graph_conv + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/ST_SSL/PEMSD7.yaml b/config/ST_SSL/PEMSD7.yaml new file mode 100644 index 0000000..1c0fe31 --- /dev/null +++ b/config/ST_SSL/PEMSD7.yaml @@ -0,0 +1,53 @@ +data: + num_nodes: 883 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + n_his: 12 + d_model: 64 + dropout: 0.1 + nmb_prototype: 10 + shm_temp: 0.1 + yita: 0.5 + percent: 0.1 + gso_type: sym_norm_lap + graph_conv_type: cheb_graph_conv + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/config/ST_SSL/PEMSD8.yaml b/config/ST_SSL/PEMSD8.yaml new file mode 100644 index 0000000..f9440e3 --- /dev/null +++ b/config/ST_SSL/PEMSD8.yaml @@ -0,0 +1,53 @@ +data: + num_nodes: 170 + lag: 12 + horizon: 12 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 288 + days_per_week: 7 + +model: + input_dim: 1 + output_dim: 1 + n_his: 12 + d_model: 64 + dropout: 0.1 + nmb_prototype: 10 + shm_temp: 0.1 + yita: 0.5 + percent: 0.1 + gso_type: sym_norm_lap + graph_conv_type: cheb_graph_conv + +train: + loss_func: mae + seed: 10 + batch_size: 64 + epochs: 300 + lr_init: 0.003 + weight_decay: 0 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + +test: + mae_thresh: null + mape_thresh: 0.0 + +log: + log_step: 200 + plot: False + + diff --git a/model/model_selector.py b/model/model_selector.py index 796c814..9cb75f8 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -18,6 +18,8 @@ from model.STIDGCN.STIDGCN import STIDGCN from model.STID.STID import STID from model.STAEFormer.STAEFormer import STAEformer from model.EXP.EXP32 import EXP as EXP +from model.MegaCRN.MegaCRNModel import MegaCRNModel +from model.ST_SSL.ST_SSL import STSSLModel def model_selector(model): match model['type']: @@ -41,4 +43,6 @@ def model_selector(model): case 'STID': return STID(model) case 'STAEFormer': return STAEformer(model) case 'EXP': return EXP(model) + case 'MegaCRN': return MegaCRNModel(model) + case 'ST_SSL': return STSSLModel(model)