From 9b3bb44552a4b5f4e1cf2da81752c19704f80e00 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 22:30:23 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=85=8D=E7=BD=AE,=20trainer?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 15 +- .vscode/settings.json | 5 + config/AGCRN/PEMSD3.yaml | 6 + config/AGCRN/PEMSD4.yaml | 6 + config/AGCRN/PEMSD7.yaml | 6 + config/AGCRN/PEMSD8.yaml | 6 + config/ARIMA/Hainan.yaml | 6 + config/ARIMA/PEMSD3.yaml | 6 + config/ARIMA/PEMSD4.yaml | 6 + config/ARIMA/PEMSD7(L).yaml | 6 + config/ARIMA/PEMSD7(M).yaml | 6 + config/ARIMA/PEMSD7.yaml | 6 + config/ARIMA/PEMSD8.yaml | 6 + config/DCRNN/PEMSD3.yaml | 6 + config/DCRNN/PEMSD4.yaml | 6 + config/DCRNN/PEMSD7.yaml | 6 + config/DCRNN/PEMSD8.yaml | 6 + config/DDGCRN/Hainan.yaml | 6 + config/DDGCRN/PEMSD3.yaml | 6 + config/DDGCRN/PEMSD4.yaml | 6 + config/DDGCRN/PEMSD7(L).yaml | 6 + config/DDGCRN/PEMSD7(M).yaml | 6 + config/DDGCRN/PEMSD7.yaml | 6 + config/DSANET/PEMSD3.yaml | 6 + config/DSANET/PEMSD4.yaml | 6 + config/DSANET/PEMSD7.yaml | 6 + config/DSANET/PEMSD8.yaml | 6 + config/EXP/PEMSD3.yaml | 6 + config/EXP/PEMSD4.yaml | 6 + config/EXP/PEMSD7.yaml | 6 + config/EXP/SD.yaml | 6 + config/EXPB/PEMSD4.yaml | 6 + config/GWN/PEMSD3.yaml | 6 + config/GWN/PEMSD4.yaml | 6 + config/GWN/PEMSD7.yaml | 6 + config/GWN/PEMSD8.yaml | 6 + config/MegaCRN/PEMSD3.yaml | 6 + config/MegaCRN/PEMSD4.yaml | 6 + config/MegaCRN/PEMSD7.yaml | 6 + config/MegaCRN/PEMSD8.yaml | 6 + config/NLT/PEMSD3.yaml | 6 + config/NLT/PEMSD4.yaml | 6 + config/NLT/PEMSD7.yaml | 6 + config/NLT/PEMSD8.yaml | 6 + config/PDG2SEQ/PEMSD3.yaml | 6 + config/PDG2SEQ/PEMSD4.yaml | 6 + config/PDG2SEQ/PEMSD7.yaml | 6 + config/PDG2SEQ/PEMSD8.yaml | 6 + config/STAEFormer/PEMSD3.yaml | 6 + config/STAEFormer/PEMSD4.yaml | 6 + config/STAEFormer/PEMSD7.yaml | 6 + config/STAEFormer/PEMSD8.yaml | 6 + config/STAWnet/PEMSD3.yaml | 6 + config/STAWnet/PEMSD4.yaml | 6 + config/STAWnet/PEMSD7.yaml | 6 + config/STAWnet/PEMSD8.yaml | 6 + config/STFGNN/PEMSD3.yaml | 6 + config/STFGNN/PEMSD4.yaml | 6 + config/STFGNN/PEMSD7.yaml | 6 + config/STFGNN/PEMSD8.yaml | 6 + config/STGCN/PEMSD3.yaml | 6 + config/STGCN/PEMSD4.yaml | 6 + config/STGCN/PEMSD7.yaml | 6 + config/STGCN/PEMSD8.yaml | 6 + config/STGNCDE/PEMSD3.yaml | 6 + config/STGNCDE/PEMSD4.yaml | 6 + config/STGNCDE/PEMSD7.yaml | 6 + config/STGNCDE/PEMSD8.yaml | 6 + config/STGNRDE/PEMSD3.yaml | 6 + config/STGNRDE/PEMSD4.yaml | 6 + config/STGNRDE/PEMSD7.yaml | 6 + config/STGNRDE/PEMSD8.yaml | 6 + config/STGODE/PEMSD3.yaml | 6 + config/STGODE/PEMSD4.yaml | 6 + config/STGODE/PEMSD7.yaml | 6 + config/STGODE/PEMSD8.yaml | 6 + config/STID/PEMSD4.yaml | 6 + config/STIDGCN/PEMSD3.yaml | 6 + config/STIDGCN/PEMSD4.yaml | 6 + config/STIDGCN/PEMSD7.yaml | 6 + config/STIDGCN/PEMSD8.yaml | 6 + config/STMLP/PEMSD3.yaml | 6 + config/STMLP/PEMSD4.yaml | 6 + config/STMLP/PEMSD7.yaml | 6 + config/STMLP/PEMSD8.yaml | 6 + config/STSGCN/PEMSD3.yaml | 6 + config/STSGCN/PEMSD4.yaml | 6 + config/STSGCN/PEMSD7.yaml | 6 + config/STSGCN/PEMSD8.yaml | 6 + config/ST_SSL/PEMSD3.yaml | 6 + config/ST_SSL/PEMSD4.yaml | 6 + config/ST_SSL/PEMSD7.yaml | 6 + config/ST_SSL/PEMSD8.yaml | 6 + config/TCN/PEMSD3.yaml | 6 + config/TCN/PEMSD4.yaml | 6 + config/TCN/PEMSD7.yaml | 6 + config/TCN/PEMSD8.yaml | 6 + config/TWDGCN/Hainan.yaml | 6 + config/TWDGCN/PEMSD3.yaml | 6 + config/TWDGCN/PEMSD4.yaml | 6 + config/TWDGCN/PEMSD7(L).yaml | 6 + config/TWDGCN/PEMSD7(M).yaml | 6 + config/TWDGCN/PEMSD7.yaml | 6 + config/TWDGCN/PEMSD8.yaml | 6 + dataloader/PeMSDdataloader.py | 3 - dataloader/data_selector.py | 5 +- model/DDGCRN/DDGCRN_old.py | 298 ---------------------------------- model/STID/STID.py | 84 +++------- trainer/Trainer.py | 9 +- 109 files changed, 664 insertions(+), 367 deletions(-) create mode 100644 .vscode/settings.json delete mode 100755 model/DDGCRN/DDGCRN_old.py diff --git a/.vscode/launch.json b/.vscode/launch.json index 752af1d..ebce381 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,14 +4,21 @@ // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - - { - "name": "EXP_PEMSD8", +{ + "name": "STID_PEMS-BAY", "type": "debugpy", "request": "launch", "program": "run.py", "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD8.yaml" + "args": "--config ./config/STID/PEMS-BAY.yaml" + }, + { + "name": "STID_PEMSD4", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/PEMSD4.yaml" }, { "name": "REPST", diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..a8c2003 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,5 @@ +{ + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda", + "python-envs.pythonProjects": [] +} \ No newline at end of file diff --git a/config/AGCRN/PEMSD3.yaml b/config/AGCRN/PEMSD3.yaml index 1a28613..92db310 100755 --- a/config/AGCRN/PEMSD3.yaml +++ b/config/AGCRN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "AGCRN" + data: num_nodes: 358 lag: 12 diff --git a/config/AGCRN/PEMSD4.yaml b/config/AGCRN/PEMSD4.yaml index d838c28..9a3b04e 100755 --- a/config/AGCRN/PEMSD4.yaml +++ b/config/AGCRN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "AGCRN" + data: num_nodes: 307 lag: 12 diff --git a/config/AGCRN/PEMSD7.yaml b/config/AGCRN/PEMSD7.yaml index 9f5a0ee..4aeaf43 100755 --- a/config/AGCRN/PEMSD7.yaml +++ b/config/AGCRN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "AGCRN" + data: num_nodes: 883 lag: 12 diff --git a/config/AGCRN/PEMSD8.yaml b/config/AGCRN/PEMSD8.yaml index 118b428..eb96c35 100755 --- a/config/AGCRN/PEMSD8.yaml +++ b/config/AGCRN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "AGCRN" + data: num_nodes: 170 lag: 12 diff --git a/config/ARIMA/Hainan.yaml b/config/ARIMA/Hainan.yaml index d75ad6b..f084cf0 100755 --- a/config/ARIMA/Hainan.yaml +++ b/config/ARIMA/Hainan.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "Hainan" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: num_nodes: 13 lag: 12 diff --git a/config/ARIMA/PEMSD3.yaml b/config/ARIMA/PEMSD3.yaml index 9038770..3068051 100755 --- a/config/ARIMA/PEMSD3.yaml +++ b/config/ARIMA/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: num_nodes: 358 lag: 12 diff --git a/config/ARIMA/PEMSD4.yaml b/config/ARIMA/PEMSD4.yaml index e93bf84..2e0198c 100755 --- a/config/ARIMA/PEMSD4.yaml +++ b/config/ARIMA/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: num_nodes: 307 lag: 12 diff --git a/config/ARIMA/PEMSD7(L).yaml b/config/ARIMA/PEMSD7(L).yaml index 1994680..be0ca07 100755 --- a/config/ARIMA/PEMSD7(L).yaml +++ b/config/ARIMA/PEMSD7(L).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(L)" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/ARIMA/PEMSD7(M).yaml b/config/ARIMA/PEMSD7(M).yaml index c70a065..6f5d7cf 100755 --- a/config/ARIMA/PEMSD7(M).yaml +++ b/config/ARIMA/PEMSD7(M).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(M)" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/ARIMA/PEMSD7.yaml b/config/ARIMA/PEMSD7.yaml index becc9a0..f3ab397 100755 --- a/config/ARIMA/PEMSD7.yaml +++ b/config/ARIMA/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/ARIMA/PEMSD8.yaml b/config/ARIMA/PEMSD8.yaml index bb590c9..c83e25a 100755 --- a/config/ARIMA/PEMSD8.yaml +++ b/config/ARIMA/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "ARIMA" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/DCRNN/PEMSD3.yaml b/config/DCRNN/PEMSD3.yaml index a13d673..84ba071 100755 --- a/config/DCRNN/PEMSD3.yaml +++ b/config/DCRNN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "DCRNN" + data: num_nodes: 358 lag: 12 diff --git a/config/DCRNN/PEMSD4.yaml b/config/DCRNN/PEMSD4.yaml index 8758786..37f057d 100755 --- a/config/DCRNN/PEMSD4.yaml +++ b/config/DCRNN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "DCRNN" + data: num_nodes: 307 lag: 12 diff --git a/config/DCRNN/PEMSD7.yaml b/config/DCRNN/PEMSD7.yaml index e771f53..74d3790 100755 --- a/config/DCRNN/PEMSD7.yaml +++ b/config/DCRNN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "DCRNN" + data: num_nodes: 883 lag: 12 diff --git a/config/DCRNN/PEMSD8.yaml b/config/DCRNN/PEMSD8.yaml index ff33ef1..781bfa7 100755 --- a/config/DCRNN/PEMSD8.yaml +++ b/config/DCRNN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "DCRNN" + data: num_nodes: 170 lag: 12 diff --git a/config/DDGCRN/Hainan.yaml b/config/DDGCRN/Hainan.yaml index 7bbfa3b..0f7b7bf 100755 --- a/config/DDGCRN/Hainan.yaml +++ b/config/DDGCRN/Hainan.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "Hainan" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: num_nodes: 13 lag: 12 diff --git a/config/DDGCRN/PEMSD3.yaml b/config/DDGCRN/PEMSD3.yaml index 12854a7..9ecf9b4 100755 --- a/config/DDGCRN/PEMSD3.yaml +++ b/config/DDGCRN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: num_nodes: 358 lag: 12 diff --git a/config/DDGCRN/PEMSD4.yaml b/config/DDGCRN/PEMSD4.yaml index 2081849..417fdea 100755 --- a/config/DDGCRN/PEMSD4.yaml +++ b/config/DDGCRN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: num_nodes: 307 lag: 12 diff --git a/config/DDGCRN/PEMSD7(L).yaml b/config/DDGCRN/PEMSD7(L).yaml index 1994680..a5c4156 100755 --- a/config/DDGCRN/PEMSD7(L).yaml +++ b/config/DDGCRN/PEMSD7(L).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(L)" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/DDGCRN/PEMSD7(M).yaml b/config/DDGCRN/PEMSD7(M).yaml index c70a065..dacedf5 100755 --- a/config/DDGCRN/PEMSD7(M).yaml +++ b/config/DDGCRN/PEMSD7(M).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(M)" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/DDGCRN/PEMSD7.yaml b/config/DDGCRN/PEMSD7.yaml index c006075..24ef3b4 100755 --- a/config/DDGCRN/PEMSD7.yaml +++ b/config/DDGCRN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "DDGCRN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/DSANET/PEMSD3.yaml b/config/DSANET/PEMSD3.yaml index c87e1ef..d3e705a 100755 --- a/config/DSANET/PEMSD3.yaml +++ b/config/DSANET/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "DSANET" + data: num_nodes: 358 lag: 12 diff --git a/config/DSANET/PEMSD4.yaml b/config/DSANET/PEMSD4.yaml index 3a8c533..ca92018 100755 --- a/config/DSANET/PEMSD4.yaml +++ b/config/DSANET/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "DSANET" + data: num_nodes: 307 lag: 12 diff --git a/config/DSANET/PEMSD7.yaml b/config/DSANET/PEMSD7.yaml index f10b5fe..ba02c7d 100755 --- a/config/DSANET/PEMSD7.yaml +++ b/config/DSANET/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "DSANET" + data: num_nodes: 883 lag: 12 diff --git a/config/DSANET/PEMSD8.yaml b/config/DSANET/PEMSD8.yaml index 4eb9219..efd8bcd 100755 --- a/config/DSANET/PEMSD8.yaml +++ b/config/DSANET/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "DSANET" + data: num_nodes: 170 lag: 12 diff --git a/config/EXP/PEMSD3.yaml b/config/EXP/PEMSD3.yaml index 40fadfa..d6036df 100755 --- a/config/EXP/PEMSD3.yaml +++ b/config/EXP/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "EXP" + data: diff --git a/config/EXP/PEMSD4.yaml b/config/EXP/PEMSD4.yaml index 8c15601..2f1b70a 100755 --- a/config/EXP/PEMSD4.yaml +++ b/config/EXP/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "EXP" + data: num_nodes: 307 lag: 12 diff --git a/config/EXP/PEMSD7.yaml b/config/EXP/PEMSD7.yaml index 430e1bb..8f7d41e 100755 --- a/config/EXP/PEMSD7.yaml +++ b/config/EXP/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "EXP" + data: num_nodes: 883 lag: 12 diff --git a/config/EXP/SD.yaml b/config/EXP/SD.yaml index 977b2e8..2b4d178 100755 --- a/config/EXP/SD.yaml +++ b/config/EXP/SD.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "SD" + mode: "train" + device: "cuda:0" + model: "EXP" + data: num_nodes: 716 lag: 12 diff --git a/config/EXPB/PEMSD4.yaml b/config/EXPB/PEMSD4.yaml index dae7923..ecfa912 100755 --- a/config/EXPB/PEMSD4.yaml +++ b/config/EXPB/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "EXPB" + data: num_nodes: 307 lag: 12 diff --git a/config/GWN/PEMSD3.yaml b/config/GWN/PEMSD3.yaml index 529a928..4c3bafc 100755 --- a/config/GWN/PEMSD3.yaml +++ b/config/GWN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "GWN" + data: num_nodes: 358 lag: 12 diff --git a/config/GWN/PEMSD4.yaml b/config/GWN/PEMSD4.yaml index fa84d8e..b6b7a4d 100755 --- a/config/GWN/PEMSD4.yaml +++ b/config/GWN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "GWN" + data: num_nodes: 307 lag: 12 diff --git a/config/GWN/PEMSD7.yaml b/config/GWN/PEMSD7.yaml index f8f675b..09ef25c 100755 --- a/config/GWN/PEMSD7.yaml +++ b/config/GWN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "GWN" + data: num_nodes: 883 lag: 12 diff --git a/config/GWN/PEMSD8.yaml b/config/GWN/PEMSD8.yaml index 39b95e3..1912b9b 100755 --- a/config/GWN/PEMSD8.yaml +++ b/config/GWN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "GWN" + data: num_nodes: 170 lag: 12 diff --git a/config/MegaCRN/PEMSD3.yaml b/config/MegaCRN/PEMSD3.yaml index 4d90c7c..19c6d87 100644 --- a/config/MegaCRN/PEMSD3.yaml +++ b/config/MegaCRN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "MegaCRN" + data: num_nodes: 358 lag: 12 diff --git a/config/MegaCRN/PEMSD4.yaml b/config/MegaCRN/PEMSD4.yaml index cebb3be..33b26f7 100644 --- a/config/MegaCRN/PEMSD4.yaml +++ b/config/MegaCRN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "MegaCRN" + data: num_nodes: 307 lag: 12 diff --git a/config/MegaCRN/PEMSD7.yaml b/config/MegaCRN/PEMSD7.yaml index 965ef14..6dbd633 100644 --- a/config/MegaCRN/PEMSD7.yaml +++ b/config/MegaCRN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "MegaCRN" + data: num_nodes: 883 lag: 12 diff --git a/config/MegaCRN/PEMSD8.yaml b/config/MegaCRN/PEMSD8.yaml index 3d00c33..254d0fe 100644 --- a/config/MegaCRN/PEMSD8.yaml +++ b/config/MegaCRN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "MegaCRN" + data: num_nodes: 170 lag: 12 diff --git a/config/NLT/PEMSD3.yaml b/config/NLT/PEMSD3.yaml index 6093ebc..c59cfbe 100755 --- a/config/NLT/PEMSD3.yaml +++ b/config/NLT/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "NLT" + data: num_nodes: 358 lag: 12 diff --git a/config/NLT/PEMSD4.yaml b/config/NLT/PEMSD4.yaml index ee576f5..46a7ad7 100755 --- a/config/NLT/PEMSD4.yaml +++ b/config/NLT/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "NLT" + data: num_nodes: 307 lag: 12 diff --git a/config/NLT/PEMSD7.yaml b/config/NLT/PEMSD7.yaml index c671bd2..692b76b 100755 --- a/config/NLT/PEMSD7.yaml +++ b/config/NLT/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "NLT" + data: num_nodes: 883 lag: 12 diff --git a/config/NLT/PEMSD8.yaml b/config/NLT/PEMSD8.yaml index 91f3581..5828019 100755 --- a/config/NLT/PEMSD8.yaml +++ b/config/NLT/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "NLT" + data: num_nodes: 170 lag: 12 diff --git a/config/PDG2SEQ/PEMSD3.yaml b/config/PDG2SEQ/PEMSD3.yaml index 438ba5c..5d1e619 100755 --- a/config/PDG2SEQ/PEMSD3.yaml +++ b/config/PDG2SEQ/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "PDG2SEQ" + data: num_nodes: 358 lag: 12 diff --git a/config/PDG2SEQ/PEMSD4.yaml b/config/PDG2SEQ/PEMSD4.yaml index 687054d..84b9394 100755 --- a/config/PDG2SEQ/PEMSD4.yaml +++ b/config/PDG2SEQ/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "PDG2SEQ" + data: num_nodes: 307 lag: 12 diff --git a/config/PDG2SEQ/PEMSD7.yaml b/config/PDG2SEQ/PEMSD7.yaml index 1b21274..06d4cb6 100755 --- a/config/PDG2SEQ/PEMSD7.yaml +++ b/config/PDG2SEQ/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "PDG2SEQ" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/PDG2SEQ/PEMSD8.yaml b/config/PDG2SEQ/PEMSD8.yaml index c1f7f6e..3a3b681 100755 --- a/config/PDG2SEQ/PEMSD8.yaml +++ b/config/PDG2SEQ/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "PDG2SEQ" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/STAEFormer/PEMSD3.yaml b/config/STAEFormer/PEMSD3.yaml index 2014b44..819b986 100755 --- a/config/STAEFormer/PEMSD3.yaml +++ b/config/STAEFormer/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STAEFormer" + data: num_nodes: 358 lag: 12 diff --git a/config/STAEFormer/PEMSD4.yaml b/config/STAEFormer/PEMSD4.yaml index bb94654..240e8db 100755 --- a/config/STAEFormer/PEMSD4.yaml +++ b/config/STAEFormer/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STAEFormer" + data: num_nodes: 307 lag: 12 diff --git a/config/STAEFormer/PEMSD7.yaml b/config/STAEFormer/PEMSD7.yaml index a593bd3..f11545e 100755 --- a/config/STAEFormer/PEMSD7.yaml +++ b/config/STAEFormer/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STAEFormer" + data: num_nodes: 883 lag: 12 diff --git a/config/STAEFormer/PEMSD8.yaml b/config/STAEFormer/PEMSD8.yaml index f625d5d..791ba1d 100755 --- a/config/STAEFormer/PEMSD8.yaml +++ b/config/STAEFormer/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STAEFormer" + data: num_nodes: 170 lag: 12 diff --git a/config/STAWnet/PEMSD3.yaml b/config/STAWnet/PEMSD3.yaml index d54087e..ff48317 100644 --- a/config/STAWnet/PEMSD3.yaml +++ b/config/STAWnet/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STAWnet" + data: num_nodes: 358 lag: 12 diff --git a/config/STAWnet/PEMSD4.yaml b/config/STAWnet/PEMSD4.yaml index b3346b1..5c0aa37 100644 --- a/config/STAWnet/PEMSD4.yaml +++ b/config/STAWnet/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STAWnet" + data: num_nodes: 307 lag: 12 diff --git a/config/STAWnet/PEMSD7.yaml b/config/STAWnet/PEMSD7.yaml index a314c04..d58125e 100644 --- a/config/STAWnet/PEMSD7.yaml +++ b/config/STAWnet/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STAWnet" + data: num_nodes: 883 lag: 12 diff --git a/config/STAWnet/PEMSD8.yaml b/config/STAWnet/PEMSD8.yaml index 75beec1..ccb4619 100644 --- a/config/STAWnet/PEMSD8.yaml +++ b/config/STAWnet/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STAWnet" + data: num_nodes: 170 lag: 12 diff --git a/config/STFGNN/PEMSD3.yaml b/config/STFGNN/PEMSD3.yaml index a8184d8..d8c09fd 100755 --- a/config/STFGNN/PEMSD3.yaml +++ b/config/STFGNN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STFGNN" + data: num_nodes: 358 lag: 12 diff --git a/config/STFGNN/PEMSD4.yaml b/config/STFGNN/PEMSD4.yaml index 3f9e26f..3903082 100755 --- a/config/STFGNN/PEMSD4.yaml +++ b/config/STFGNN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STFGNN" + data: num_nodes: 307 lag: 12 diff --git a/config/STFGNN/PEMSD7.yaml b/config/STFGNN/PEMSD7.yaml index 0092907..bd211ec 100755 --- a/config/STFGNN/PEMSD7.yaml +++ b/config/STFGNN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STFGNN" + data: num_nodes: 883 lag: 12 diff --git a/config/STFGNN/PEMSD8.yaml b/config/STFGNN/PEMSD8.yaml index f9bca0f..9f8b022 100755 --- a/config/STFGNN/PEMSD8.yaml +++ b/config/STFGNN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STFGNN" + data: num_nodes: 170 lag: 12 diff --git a/config/STGCN/PEMSD3.yaml b/config/STGCN/PEMSD3.yaml index af9d5a4..51f22fc 100755 --- a/config/STGCN/PEMSD3.yaml +++ b/config/STGCN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STGCN" + data: num_nodes: 358 lag: 12 diff --git a/config/STGCN/PEMSD4.yaml b/config/STGCN/PEMSD4.yaml index 1fc5146..64164cc 100755 --- a/config/STGCN/PEMSD4.yaml +++ b/config/STGCN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STGCN" + data: num_nodes: 307 lag: 12 diff --git a/config/STGCN/PEMSD7.yaml b/config/STGCN/PEMSD7.yaml index 5163e26..8b241f9 100755 --- a/config/STGCN/PEMSD7.yaml +++ b/config/STGCN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STGCN" + data: num_nodes: 883 lag: 12 diff --git a/config/STGCN/PEMSD8.yaml b/config/STGCN/PEMSD8.yaml index 9361926..8f2fc4d 100755 --- a/config/STGCN/PEMSD8.yaml +++ b/config/STGCN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STGCN" + data: num_nodes: 170 lag: 12 diff --git a/config/STGNCDE/PEMSD3.yaml b/config/STGNCDE/PEMSD3.yaml index c0c8f96..5aaec14 100755 --- a/config/STGNCDE/PEMSD3.yaml +++ b/config/STGNCDE/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STGNCDE" + data: num_nodes: 358 lag: 12 diff --git a/config/STGNCDE/PEMSD4.yaml b/config/STGNCDE/PEMSD4.yaml index e729bcd..7ae0964 100755 --- a/config/STGNCDE/PEMSD4.yaml +++ b/config/STGNCDE/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STGNCDE" + data: num_nodes: 307 lag: 12 diff --git a/config/STGNCDE/PEMSD7.yaml b/config/STGNCDE/PEMSD7.yaml index f7d2d58..1761fc8 100755 --- a/config/STGNCDE/PEMSD7.yaml +++ b/config/STGNCDE/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STGNCDE" + data: num_nodes: 883 lag: 12 diff --git a/config/STGNCDE/PEMSD8.yaml b/config/STGNCDE/PEMSD8.yaml index fae9ed7..7da837c 100755 --- a/config/STGNCDE/PEMSD8.yaml +++ b/config/STGNCDE/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STGNCDE" + data: num_nodes: 170 lag: 12 diff --git a/config/STGNRDE/PEMSD3.yaml b/config/STGNRDE/PEMSD3.yaml index 3d98b05..1199815 100644 --- a/config/STGNRDE/PEMSD3.yaml +++ b/config/STGNRDE/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STGNRDE" + data: num_nodes: 358 lag: 12 diff --git a/config/STGNRDE/PEMSD4.yaml b/config/STGNRDE/PEMSD4.yaml index c75617c..f04b1c0 100644 --- a/config/STGNRDE/PEMSD4.yaml +++ b/config/STGNRDE/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STGNRDE" + data: num_nodes: 307 lag: 12 diff --git a/config/STGNRDE/PEMSD7.yaml b/config/STGNRDE/PEMSD7.yaml index e135a63..4808d20 100644 --- a/config/STGNRDE/PEMSD7.yaml +++ b/config/STGNRDE/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STGNRDE" + data: num_nodes: 883 lag: 12 diff --git a/config/STGNRDE/PEMSD8.yaml b/config/STGNRDE/PEMSD8.yaml index 3c0751b..5521dc7 100644 --- a/config/STGNRDE/PEMSD8.yaml +++ b/config/STGNRDE/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STGNRDE" + data: num_nodes: 170 lag: 12 diff --git a/config/STGODE/PEMSD3.yaml b/config/STGODE/PEMSD3.yaml index d5d7d99..ae31781 100755 --- a/config/STGODE/PEMSD3.yaml +++ b/config/STGODE/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STGODE" + data: num_nodes: 358 lag: 12 diff --git a/config/STGODE/PEMSD4.yaml b/config/STGODE/PEMSD4.yaml index 4a09717..3caa8c8 100755 --- a/config/STGODE/PEMSD4.yaml +++ b/config/STGODE/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STGODE" + data: num_nodes: 307 lag: 12 diff --git a/config/STGODE/PEMSD7.yaml b/config/STGODE/PEMSD7.yaml index b9d34e9..c868c90 100755 --- a/config/STGODE/PEMSD7.yaml +++ b/config/STGODE/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STGODE" + data: num_nodes: 883 lag: 12 diff --git a/config/STGODE/PEMSD8.yaml b/config/STGODE/PEMSD8.yaml index c0c5a07..26ef07b 100755 --- a/config/STGODE/PEMSD8.yaml +++ b/config/STGODE/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STGODE" + data: num_nodes: 170 lag: 12 diff --git a/config/STID/PEMSD4.yaml b/config/STID/PEMSD4.yaml index 8cf6ba5..dfb0726 100755 --- a/config/STID/PEMSD4.yaml +++ b/config/STID/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STID" + data: num_nodes: 307 lag: 12 diff --git a/config/STIDGCN/PEMSD3.yaml b/config/STIDGCN/PEMSD3.yaml index c3d8ce2..139d014 100644 --- a/config/STIDGCN/PEMSD3.yaml +++ b/config/STIDGCN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STIDGCN" + data: num_nodes: 358 lag: 12 diff --git a/config/STIDGCN/PEMSD4.yaml b/config/STIDGCN/PEMSD4.yaml index fb4dae6..8509b54 100644 --- a/config/STIDGCN/PEMSD4.yaml +++ b/config/STIDGCN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STIDGCN" + data: num_nodes: 307 lag: 12 diff --git a/config/STIDGCN/PEMSD7.yaml b/config/STIDGCN/PEMSD7.yaml index 4f67ce1..f79f49f 100644 --- a/config/STIDGCN/PEMSD7.yaml +++ b/config/STIDGCN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STIDGCN" + data: num_nodes: 883 lag: 12 diff --git a/config/STIDGCN/PEMSD8.yaml b/config/STIDGCN/PEMSD8.yaml index 354c9e3..3bdccd3 100644 --- a/config/STIDGCN/PEMSD8.yaml +++ b/config/STIDGCN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STIDGCN" + data: num_nodes: 170 lag: 12 diff --git a/config/STMLP/PEMSD3.yaml b/config/STMLP/PEMSD3.yaml index eee7a15..700dbaa 100644 --- a/config/STMLP/PEMSD3.yaml +++ b/config/STMLP/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STMLP" + data: num_nodes: 358 lag: 12 diff --git a/config/STMLP/PEMSD4.yaml b/config/STMLP/PEMSD4.yaml index c416fc4..6240857 100644 --- a/config/STMLP/PEMSD4.yaml +++ b/config/STMLP/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STMLP" + data: num_nodes: 307 lag: 12 diff --git a/config/STMLP/PEMSD7.yaml b/config/STMLP/PEMSD7.yaml index 14e6382..8080394 100644 --- a/config/STMLP/PEMSD7.yaml +++ b/config/STMLP/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STMLP" + data: num_nodes: 883 lag: 12 diff --git a/config/STMLP/PEMSD8.yaml b/config/STMLP/PEMSD8.yaml index bceffa5..c1fb16b 100644 --- a/config/STMLP/PEMSD8.yaml +++ b/config/STMLP/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STMLP" + data: num_nodes: 170 lag: 12 diff --git a/config/STSGCN/PEMSD3.yaml b/config/STSGCN/PEMSD3.yaml index 8d37e2f..00a740c 100755 --- a/config/STSGCN/PEMSD3.yaml +++ b/config/STSGCN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "STSGCN" + data: num_nodes: 358 lag: 12 diff --git a/config/STSGCN/PEMSD4.yaml b/config/STSGCN/PEMSD4.yaml index f1d9ca1..e6e159f 100755 --- a/config/STSGCN/PEMSD4.yaml +++ b/config/STSGCN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "STSGCN" + data: num_nodes: 307 lag: 12 diff --git a/config/STSGCN/PEMSD7.yaml b/config/STSGCN/PEMSD7.yaml index 8f3148e..34dc0b0 100755 --- a/config/STSGCN/PEMSD7.yaml +++ b/config/STSGCN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "STSGCN" + data: num_nodes: 883 lag: 12 diff --git a/config/STSGCN/PEMSD8.yaml b/config/STSGCN/PEMSD8.yaml index 6567ac3..b5267b6 100755 --- a/config/STSGCN/PEMSD8.yaml +++ b/config/STSGCN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "STSGCN" + data: num_nodes: 170 lag: 12 diff --git a/config/ST_SSL/PEMSD3.yaml b/config/ST_SSL/PEMSD3.yaml index 3c2f488..0ca64d5 100644 --- a/config/ST_SSL/PEMSD3.yaml +++ b/config/ST_SSL/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "ST_SSL" + data: num_nodes: 358 lag: 12 diff --git a/config/ST_SSL/PEMSD4.yaml b/config/ST_SSL/PEMSD4.yaml index cf44a11..03178ff 100644 --- a/config/ST_SSL/PEMSD4.yaml +++ b/config/ST_SSL/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "ST_SSL" + data: num_nodes: 307 lag: 12 diff --git a/config/ST_SSL/PEMSD7.yaml b/config/ST_SSL/PEMSD7.yaml index 1c0fe31..c952ed1 100644 --- a/config/ST_SSL/PEMSD7.yaml +++ b/config/ST_SSL/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "ST_SSL" + data: num_nodes: 883 lag: 12 diff --git a/config/ST_SSL/PEMSD8.yaml b/config/ST_SSL/PEMSD8.yaml index f9440e3..b057328 100644 --- a/config/ST_SSL/PEMSD8.yaml +++ b/config/ST_SSL/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "ST_SSL" + data: num_nodes: 170 lag: 12 diff --git a/config/TCN/PEMSD3.yaml b/config/TCN/PEMSD3.yaml index b05e6e3..ae281dd 100755 --- a/config/TCN/PEMSD3.yaml +++ b/config/TCN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "TCN" + data: num_nodes: 358 lag: 12 diff --git a/config/TCN/PEMSD4.yaml b/config/TCN/PEMSD4.yaml index ff37ea3..f5dbc3f 100755 --- a/config/TCN/PEMSD4.yaml +++ b/config/TCN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "TCN" + data: num_nodes: 307 lag: 12 diff --git a/config/TCN/PEMSD7.yaml b/config/TCN/PEMSD7.yaml index 26c04b9..99e7d4e 100755 --- a/config/TCN/PEMSD7.yaml +++ b/config/TCN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "TCN" + data: num_nodes: 883 lag: 12 diff --git a/config/TCN/PEMSD8.yaml b/config/TCN/PEMSD8.yaml index cf20985..006be57 100755 --- a/config/TCN/PEMSD8.yaml +++ b/config/TCN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "TCN" + data: num_nodes: 170 lag: 12 diff --git a/config/TWDGCN/Hainan.yaml b/config/TWDGCN/Hainan.yaml index 7bbfa3b..3fea40b 100755 --- a/config/TWDGCN/Hainan.yaml +++ b/config/TWDGCN/Hainan.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "Hainan" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: num_nodes: 13 lag: 12 diff --git a/config/TWDGCN/PEMSD3.yaml b/config/TWDGCN/PEMSD3.yaml index 7f0eb9b..646365d 100755 --- a/config/TWDGCN/PEMSD3.yaml +++ b/config/TWDGCN/PEMSD3.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD3" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: num_nodes: 358 lag: 12 diff --git a/config/TWDGCN/PEMSD4.yaml b/config/TWDGCN/PEMSD4.yaml index c3cbf12..38910e8 100755 --- a/config/TWDGCN/PEMSD4.yaml +++ b/config/TWDGCN/PEMSD4.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD4" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: num_nodes: 307 lag: 12 diff --git a/config/TWDGCN/PEMSD7(L).yaml b/config/TWDGCN/PEMSD7(L).yaml index 1994680..2c725cc 100755 --- a/config/TWDGCN/PEMSD7(L).yaml +++ b/config/TWDGCN/PEMSD7(L).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(L)" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/TWDGCN/PEMSD7(M).yaml b/config/TWDGCN/PEMSD7(M).yaml index c70a065..d8ef8dd 100755 --- a/config/TWDGCN/PEMSD7(M).yaml +++ b/config/TWDGCN/PEMSD7(M).yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7(M)" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/TWDGCN/PEMSD7.yaml b/config/TWDGCN/PEMSD7.yaml index 1ba5a0d..14590af 100755 --- a/config/TWDGCN/PEMSD7.yaml +++ b/config/TWDGCN/PEMSD7.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD7" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: add_day_in_week: true add_time_in_day: true diff --git a/config/TWDGCN/PEMSD8.yaml b/config/TWDGCN/PEMSD8.yaml index 33cd596..e03ba8a 100755 --- a/config/TWDGCN/PEMSD8.yaml +++ b/config/TWDGCN/PEMSD8.yaml @@ -1,3 +1,9 @@ +basic: + dataset: "PEMSD8" + mode: "train" + device: "cuda:0" + model: "TWDGCN" + data: add_day_in_week: true add_time_in_day: true diff --git a/dataloader/PeMSDdataloader.py b/dataloader/PeMSDdataloader.py index 7c42123..720b5c3 100755 --- a/dataloader/PeMSDdataloader.py +++ b/dataloader/PeMSDdataloader.py @@ -86,7 +86,6 @@ def get_dataloader(args, normalizer="std", single=True): gc.collect() # Step 5: x_train y_train x_val y_val x_test y_test --> train val test - # train_dataloader = data_loader(x_train[..., :args['input_dim']], y_train[..., :args['input_dim']], args['batch_size'], shuffle=True, drop_last=True) train_dataloader = data_loader( x_train, y_train, args["batch_size"], shuffle=True, drop_last=True ) @@ -94,7 +93,6 @@ def get_dataloader(args, normalizer="std", single=True): del x_train, y_train gc.collect() - # val_dataloader = data_loader(x_val[..., :args['input_dim']], y_val[..., :args['input_dim']], args['batch_size'], shuffle=False, drop_last=True) val_dataloader = data_loader( x_val, y_val, args["batch_size"], shuffle=False, drop_last=True ) @@ -102,7 +100,6 @@ def get_dataloader(args, normalizer="std", single=True): del x_val, y_val gc.collect() - # test_dataloader = data_loader(x_test[..., :args['input_dim']], y_test[..., :args['input_dim']], args['batch_size'], shuffle=False, drop_last=False) test_dataloader = data_loader( x_test, y_test, args["batch_size"], shuffle=False, drop_last=False ) diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index bc5d609..19fe7f5 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -4,7 +4,7 @@ import h5py def load_st_dataset(config): dataset = config["basic"]["dataset"] - sample = config["data"]["sample"] + # sample = config["data"]["sample"] # output B, N, D match dataset: case "PEMS-BAY": @@ -66,4 +66,5 @@ def load_st_dataset(config): data = np.expand_dims(data, axis=-1) print("加载 %s 数据集中... " % dataset) - return data[::sample] + # return data[::sample] + return data diff --git a/model/DDGCRN/DDGCRN_old.py b/model/DDGCRN/DDGCRN_old.py deleted file mode 100755 index bacc038..0000000 --- a/model/DDGCRN/DDGCRN_old.py +++ /dev/null @@ -1,298 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F -from collections import OrderedDict - - -class DGCRM(nn.Module): - def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim, num_layers=1): - super(DGCRM, self).__init__() - assert num_layers >= 1, "At least one DGCRM layer is required." - - self.node_num = node_num - self.input_dim = dim_in - self.num_layers = num_layers - - # Initialize DGCRM cells - self.DGCRM_cells = nn.ModuleList( - [ - DDGCRNCell(node_num, dim_in, dim_out, cheb_k, embed_dim) - if i == 0 - else DDGCRNCell(node_num, dim_out, dim_out, cheb_k, embed_dim) - for i in range(num_layers) - ] - ) - - def forward(self, x, init_state, node_embeddings): - """ - Forward pass of the DGCRM model. - - Parameters: - - x: Input tensor of shape (B, T, N, D) - - init_state: Initial hidden states of shape (num_layers, B, N, hidden_dim) - - node_embeddings: Node embeddings - """ - assert x.shape[2] == self.node_num and x.shape[3] == self.input_dim - - seq_length = x.shape[1] - current_inputs = x - output_hidden = [] - - for i in range(self.num_layers): - state = init_state[i] - inner_states = [] - - for t in range(seq_length): - state = self.DGCRM_cells[i]( - current_inputs[:, t, :, :], - state, - [node_embeddings[0][:, t, :, :], node_embeddings[1]], - ) - inner_states.append(state) - - output_hidden.append(state) - current_inputs = torch.stack(inner_states, dim=1) - - return current_inputs, output_hidden - - def init_hidden(self, batch_size): - """ - Initialize hidden states for DGCRM layers. - - Parameters: - - batch_size: Size of the batch - - Returns: - - Initial hidden states tensor - """ - return torch.stack( - [ - self.DGCRM_cells[i].init_hidden_state(batch_size) - for i in range(self.num_layers) - ], - dim=0, - ) - - -class DDGCRN(nn.Module): - def __init__(self, args): - super(DDGCRN, self).__init__() - - self.num_node = args["num_nodes"] - self.input_dim = args["input_dim"] - self.hidden_dim = args["rnn_units"] - self.output_dim = args["output_dim"] - self.horizon = args["horizon"] - self.num_layers = args["num_layers"] - self.use_day = args["use_day"] - self.use_week = args["use_week"] - self.default_graph = args["default_graph"] - - self.node_embeddings1 = nn.Parameter( - torch.randn(self.num_node, args["embed_dim"]), requires_grad=True - ) - self.node_embeddings2 = nn.Parameter( - torch.randn(self.num_node, args["embed_dim"]), requires_grad=True - ) - self.T_i_D_emb = nn.Parameter(torch.empty(288, args["embed_dim"])) - self.D_i_W_emb = nn.Parameter(torch.empty(7, args["embed_dim"])) - - self.dropout1 = nn.Dropout(p=0.1) - self.dropout2 = nn.Dropout(p=0.1) - - self.encoder1 = DGCRM( - self.num_node, - self.input_dim, - self.hidden_dim, - args["cheb_order"], - args["embed_dim"], - self.num_layers, - ) - self.encoder2 = DGCRM( - self.num_node, - self.input_dim, - self.hidden_dim, - args["cheb_order"], - args["embed_dim"], - self.num_layers, - ) - - # Predictor - self.end_conv1 = nn.Conv2d( - 1, - self.horizon * self.output_dim, - kernel_size=(1, self.hidden_dim), - bias=True, - ) - self.end_conv2 = nn.Conv2d( - 1, - self.horizon * self.output_dim, - kernel_size=(1, self.hidden_dim), - bias=True, - ) - self.end_conv3 = nn.Conv2d( - 1, - self.horizon * self.output_dim, - kernel_size=(1, self.hidden_dim), - bias=True, - ) - - def forward(self, source, **kwargs): - """ - Forward pass of the DDGCRN model. - - Parameters: - - source: Input tensor of shape (B, T_1, N, D) - - mode: Control mode for the forward pass - - Returns: - - Output tensor - """ - node_embedding1 = self.node_embeddings1 - - if self.use_day: - t_i_d_data = source[..., 1] - T_i_D_emb = self.T_i_D_emb[(t_i_d_data * 288).long()] - node_embedding1 = node_embedding1 * T_i_D_emb - - if self.use_week: - d_i_w_data = source[..., 2] - D_i_W_emb = self.D_i_W_emb[d_i_w_data.long()] - node_embedding1 = node_embedding1 * D_i_W_emb - - node_embeddings = [node_embedding1, self.node_embeddings1] - source = source[..., 0].unsqueeze(-1) - - init_state1 = self.encoder1.init_hidden(source.shape[0]) - output, _ = self.encoder1(source, init_state1, node_embeddings) - output = self.dropout1(output[:, -1:, :, :]) - output1 = self.end_conv1(output) - - source1 = self.end_conv2(output) - source2 = source[:, -self.horizon :, ...] - source1 - - init_state2 = self.encoder2.init_hidden(source2.shape[0]) - output2, _ = self.encoder2(source2, init_state2, node_embeddings) - output2 = self.dropout2(output2[:, -1:, :, :]) - output2 = self.end_conv3(output2) - - return output1 + output2 - - -class DDGCRNCell(nn.Module): - def __init__(self, node_num, dim_in, dim_out, cheb_k, embed_dim): - super(DDGCRNCell, self).__init__() - self.node_num = node_num - self.hidden_dim = dim_out - self.gate = DGCN(dim_in + self.hidden_dim, 2 * dim_out, cheb_k, embed_dim) - self.update = DGCN(dim_in + self.hidden_dim, dim_out, cheb_k, embed_dim) - - def forward(self, x, state, node_embeddings): - state = state.to(x.device) - input_and_state = torch.cat((x, state), dim=-1) - z_r = torch.sigmoid(self.gate(input_and_state, node_embeddings)) - z, r = torch.split(z_r, self.hidden_dim, dim=-1) - candidate = torch.cat((x, z * state), dim=-1) - hc = torch.tanh(self.update(candidate, node_embeddings)) - h = r * state + (1 - r) * hc - return h - - def init_hidden_state(self, batch_size): - return torch.zeros(batch_size, self.node_num, self.hidden_dim) - - -class DGCN(nn.Module): - def __init__(self, dim_in, dim_out, cheb_k, embed_dim): - super(DGCN, self).__init__() - self.cheb_k = cheb_k - self.embed_dim = embed_dim - - # Initialize parameters - self.weights_pool = nn.Parameter( - torch.FloatTensor(embed_dim, cheb_k, dim_in, dim_out) - ) - self.weights = nn.Parameter(torch.FloatTensor(cheb_k, dim_in, dim_out)) - self.bias_pool = nn.Parameter(torch.FloatTensor(embed_dim, dim_out)) - self.bias = nn.Parameter(torch.FloatTensor(dim_out)) - - # Hyperparameters - self.hyperGNN_dim = 16 - self.middle_dim = 2 - - # Fully connected layers - self.fc = nn.Sequential( - OrderedDict( - [ - ("fc1", nn.Linear(dim_in, self.hyperGNN_dim)), - ("sigmoid1", nn.Sigmoid()), - ("fc2", nn.Linear(self.hyperGNN_dim, self.middle_dim)), - ("sigmoid2", nn.Sigmoid()), - ("fc3", nn.Linear(self.middle_dim, self.embed_dim)), - ] - ) - ) - - def forward(self, x, node_embeddings): - """ - Forward pass for the DGCN model. - - Parameters: - - x: Input tensor of shape [B, N, C] - - node_embeddings: Node embeddings tensor of shape [N, D] - - connMtx: Connectivity matrix - - Returns: - - x_gconv: Output tensor of shape [B, N, dim_out] - """ - - node_num = node_embeddings[0].shape[1] - supports1 = torch.eye(node_num).to(node_embeddings[0].device) # Identity matrix - - # Apply fully connected layers - filter = self.fc(x) - nodevec = torch.tanh( - torch.mul(node_embeddings[0], filter) - ) # Element-wise multiplication - - # Compute Laplacian - supports2 = self.get_laplacian( - F.relu(torch.matmul(nodevec, nodevec.transpose(2, 1))), supports1 - ) - - # Graph convolution - x_g1 = torch.einsum("nm,bmc->bnc", supports1, x) - x_g2 = torch.einsum("bnm,bmc->bnc", supports2, x) - x_g = torch.stack([x_g1, x_g2], dim=1) - - # Apply graph convolution weights and biases - weights = torch.einsum("nd,dkio->nkio", node_embeddings[1], self.weights_pool) - bias = torch.matmul(node_embeddings[1], self.bias_pool) - - x_g = x_g.permute(0, 2, 1, 3) # Rearrange dimensions - x_gconv = ( - torch.einsum("bnki,nkio->bno", x_g, weights) + bias - ) # Graph convolution operation - - return x_gconv - - @staticmethod - def get_laplacian(graph, I, normalize=True): - """ - Compute the Laplacian of the graph. - - Parameters: - - graph: Adjacency matrix of the graph, [N, N] - - I: Identity matrix - - normalize: Whether to use the normalized Laplacian - - Returns: - - L: Graph Laplacian - """ - if normalize: - D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2)) - L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt) - else: - graph = graph + I - D_inv_sqrt = torch.diag_embed(torch.sum(graph, dim=-1) ** (-1 / 2)) - L = torch.matmul(torch.matmul(D_inv_sqrt, graph), D_inv_sqrt) - return L diff --git a/model/STID/STID.py b/model/STID/STID.py index 596fefb..713bc02 100755 --- a/model/STID/STID.py +++ b/model/STID/STID.py @@ -1,19 +1,11 @@ import torch from torch import nn - from model.STID.MLP import MultiLayerPerceptron class STID(nn.Module): - """ - Paper: Spatial-Temporal Identity: A Simple yet Effective Baseline for Multivariate Time Series Forecasting - Link: https://arxiv.org/abs/2208.05233 - Official Code: https://github.com/zezhishao/STID - """ - def __init__(self, model_args): super().__init__() - # attributes self.num_nodes = model_args["num_nodes"] self.node_dim = model_args["node_dim"] self.input_len = model_args["input_len"] @@ -30,23 +22,22 @@ class STID(nn.Module): self.if_day_in_week = model_args["if_D_i_W"] self.if_spatial = model_args["if_node"] - # spatial embeddings if self.if_spatial: self.node_emb = nn.Parameter(torch.empty(self.num_nodes, self.node_dim)) nn.init.xavier_uniform_(self.node_emb) - # temporal embeddings + if self.if_time_in_day: self.time_in_day_emb = nn.Parameter( torch.empty(self.time_of_day_size, self.temp_dim_tid) ) nn.init.xavier_uniform_(self.time_in_day_emb) + if self.if_day_in_week: self.day_in_week_emb = nn.Parameter( torch.empty(self.day_of_week_size, self.temp_dim_diw) ) nn.init.xavier_uniform_(self.day_in_week_emb) - # embedding layer self.time_series_emb_layer = nn.Conv2d( in_channels=self.input_dim * self.input_len, out_channels=self.embed_dim, @@ -54,21 +45,16 @@ class STID(nn.Module): bias=True, ) - # encoding self.hidden_dim = ( self.embed_dim + self.node_dim * int(self.if_spatial) - + self.temp_dim_tid * int(self.if_day_in_week) - + self.temp_dim_diw * int(self.if_time_in_day) + + self.temp_dim_tid * int(self.if_time_in_day) + + self.temp_dim_diw * int(self.if_day_in_week) ) self.encoder = nn.Sequential( - *[ - MultiLayerPerceptron(self.hidden_dim, self.hidden_dim) - for _ in range(self.num_layer) - ] + *[MultiLayerPerceptron(self.hidden_dim, self.hidden_dim) for _ in range(self.num_layer)] ) - # regression self.regression_layer = nn.Conv2d( in_channels=self.hidden_dim, out_channels=self.output_len, @@ -77,67 +63,47 @@ class STID(nn.Module): ) def forward(self, history_data: torch.Tensor) -> torch.Tensor: - """Feed forward of STID. - - Args: - history_data (torch.Tensor): history data with shape [B, L, N, C] - - Returns: - torch.Tensor: prediction with shape [B, L, N, C] - """ - - # prepare data + device = history_data.device input_data = history_data[..., range(self.input_dim)] - # input_data = history_data[..., 0:1] if self.if_time_in_day: t_i_d_data = history_data[..., 1] - # In the datasets used in STID, the time_of_day feature is normalized to [0, 1]. We multiply it by 288 to get the index. - # If you use other datasets, you may need to change this line. - time_in_day_emb = self.time_in_day_emb[ - (t_i_d_data[:, -1, :] * self.time_of_day_size).type(torch.LongTensor) - ] + idx_tid = (t_i_d_data[:, -1, :] * self.time_of_day_size).long() + idx_tid = torch.clamp(idx_tid, 0, self.time_of_day_size - 1).to(device) + time_in_day_emb = self.time_in_day_emb[idx_tid] else: time_in_day_emb = None + if self.if_day_in_week: d_i_w_data = history_data[..., 2] - day_in_week_emb = self.day_in_week_emb[ - (d_i_w_data[:, -1, :] * self.day_of_week_size).type(torch.LongTensor) - ] + idx_diw = (d_i_w_data[:, -1, :] * self.day_of_week_size).long() + idx_diw = torch.clamp(idx_diw, 0, self.day_of_week_size - 1).to(device) + day_in_week_emb = self.day_in_week_emb[idx_diw] else: day_in_week_emb = None - # time series embedding - batch_size, _, num_nodes, _ = input_data.shape - input_data = input_data.transpose(1, 2).contiguous() - input_data = ( - input_data.view(batch_size, num_nodes, -1).transpose(1, 2).unsqueeze(-1) - ) - time_series_emb = self.time_series_emb_layer(input_data) + B, L, N, C = input_data.shape + x = input_data.permute(0, 3, 1, 2).reshape(B, L * C, 1, N) + x = x.to(device) + time_series_emb = self.time_series_emb_layer(x) # [B, E, 1, N] node_emb = [] if self.if_spatial: - # expand node embeddings node_emb.append( self.node_emb.unsqueeze(0) - .expand(batch_size, -1, -1) - .transpose(1, 2) - .unsqueeze(-1) + .expand(B, -1, -1) + .transpose(1, 2) + .unsqueeze(2) # ✅ [B, Dn, 1, N] ) - # temporal embeddings + tem_emb = [] if time_in_day_emb is not None: - tem_emb.append(time_in_day_emb.transpose(1, 2).unsqueeze(-1)) + tem_emb.append(time_in_day_emb.transpose(1, 2).unsqueeze(2)) # [B, Dt, 1, N] if day_in_week_emb is not None: - tem_emb.append(day_in_week_emb.transpose(1, 2).unsqueeze(-1)) + tem_emb.append(day_in_week_emb.transpose(1, 2).unsqueeze(2)) # [B, Dw, 1, N] - # concate all embeddings hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1) - - # encoding hidden = self.encoder(hidden) - - # regression prediction = self.regression_layer(hidden) - - return prediction + prediction = prediction.permute(0, 1, 3, 2) # [B, t, n, c] + return prediction # [B, t, n, c] diff --git a/trainer/Trainer.py b/trainer/Trainer.py index aa2faf3..6c62ee3 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -145,6 +145,8 @@ class Trainer: total_loss = 0 epoch_time = time.time() + y_pred, y_true = [], [] + with torch.set_grad_enabled(optimizer_step): for batch_idx, (data, target) in enumerate(dataloader): start_time = time.time() @@ -170,11 +172,16 @@ class Trainer: self.stats.record_step_time(step_time, mode) total_loss += loss.item() + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) + + y_pred = torch.cat(y_pred, dim=0) + y_true = torch.cat(y_true, dim=0) avg_loss = total_loss / len(dataloader) # 输出指标 mae, rmse, mape = all_metrics( - output, label, self.args["mae_thresh"], self.args["mape_thresh"] + y_pred, y_true, self.args["mae_thresh"], self.args["mape_thresh"] ) self.logger.info( f"Epoch #{epoch:02d}: {mode.capitalize():<5} MAE:{mae:5.2f} | RMSE:{rmse:5.2f} | MAPE:{mape:7.4f} | Time: {time.time() - epoch_time:.2f} s"