diff --git a/.vscode/launch.json b/.vscode/launch.json index 54aad8a..fb16dc0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -6,2113 +6,12 @@ "configurations": [ { - "name": "DDGCRN: METR-LA", + "name": "train", "type": "debugpy", "request": "launch", - "program": "run.py", + "program": "train.py", "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/METR-LA.yaml" - }, - // STID 模型组 - { - "name": "STID: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/PEMS-BAY.yaml" - }, - { - "name": "STID: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/METR-LA.yaml" - }, - { - "name": "STID: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/PEMSD4.yaml" - }, - { - "name": "STID: BJTaxi-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/BJTaxi_Inflow.yaml" - }, - { - "name": "STID: BJTaxi-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/BJTaxi_Outflow.yaml" - }, - { - "name": "STID: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Inflow.yaml" - }, - { - "name": "STID: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Outflow.yaml" - }, - { - "name": "STID: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/SolarEnergy.yaml" - }, - - // REPST 模型组 - { - "name": "REPST: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMSD8.yaml" - }, - { - "name": "REPST: BJTaxi-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" - }, - { - "name": "REPST: NYCBike-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-outflow.yaml" - }, - { - "name": "REPST: NYCBike-inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-inflow.yaml" - }, - { - "name": "REPST: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMS-BAY.yaml" - }, - { - "name": "REPST: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/METR-LA.yaml" - }, - { - "name": "REPST: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/SolarEnergy.yaml" - }, - { - "name": "REPST: BeijingAirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BeijingAirQuality.yaml" - }, - { - "name": "REPST: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/AirQuality.yaml" - }, - - // ASTRA 模型组 - { - "name": "ASTRA: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/PEMS-BAY.yaml" - }, - { - "name": "ASTRA: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/METR-LA.yaml" - }, - { - "name": "ASTRA: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/AirQuality.yaml" - }, - { - "name": "ASTRA: BJTaxi-Inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/BJTaxi-Inflow.yaml" - }, - { - "name": "ASTRA: BJTaxi-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/BJTaxi-outflow.yaml" - }, - { - "name": "ASTRA: NYCBike-inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-inflow.yaml" - }, - { - "name": "ASTRA: NYCBike-outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-outflow.yaml" - }, - { - "name": "ASTRA: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/SolarEnergy.yaml" - }, - { - "name": "ASTRA_v2: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_AirQuality.yaml" - }, - { - "name": "ASTRA_v2: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v2_SolarEnergy.yaml" - }, - { - "name": "ASTRA_v3: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/v3_METR-LA.yaml" - }, - { - "name": "EXPB: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/NYCBike-InFlow.yaml" - }, - { - "name": "EXPB: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/PEMSD4.yaml" - }, - { - "name": "EXPB: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/METR-LA.yaml" - }, - { - "name": "EXPB: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/AirQuality.yaml" - }, - { - "name": "EXPB: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/NYCBike-OutFlow.yaml" - }, - { - "name": "EXPB: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXPB/SolarEnergy.yaml" - }, - { - "name": "TWDGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/NYCBike-InFlow.yaml" - }, - { - "name": "TWDGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD4.yaml" - }, - { - "name": "TWDGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/METR-LA.yaml" - }, - { - "name": "TWDGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/AirQuality.yaml" - }, - { - "name": "TWDGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "TWDGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD8.yaml" - }, - { - "name": "TWDGCN: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7(L).yaml" - }, - { - "name": "TWDGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD3.yaml" - }, - { - "name": "TWDGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/SolarEnergy.yaml" - }, - { - "name": "TWDGCN: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/Hainan.yaml" - }, - { - "name": "TWDGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7.yaml" - }, - { - "name": "TWDGCN: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TWDGCN/PEMSD7(M).yaml" - }, - { - "name": "STSGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STSGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD4.yaml" - }, - { - "name": "STSGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/METR-LA.yaml" - }, - { - "name": "STSGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/AirQuality.yaml" - }, - { - "name": "STSGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STSGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD8.yaml" - }, - { - "name": "STSGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD3.yaml" - }, - { - "name": "STSGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/SolarEnergy.yaml" - }, - { - "name": "STSGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STSGCN/PEMSD7.yaml" - }, - { - "name": "STID: NYCBike_Inflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Inflow.yaml" - }, - { - "name": "STID: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/AirQuality.yaml" - }, - { - "name": "STID: NYCBike_Outflow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STID/NYCBike_Outflow.yaml" - }, - { - "name": "STAWnet: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/NYCBike-InFlow.yaml" - }, - { - "name": "STAWnet: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD4.yaml" - }, - { - "name": "STAWnet: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/METR-LA.yaml" - }, - { - "name": "STAWnet: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/AirQuality.yaml" - }, - { - "name": "STAWnet: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/NYCBike-OutFlow.yaml" - }, - { - "name": "STAWnet: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD8.yaml" - }, - { - "name": "STAWnet: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD3.yaml" - }, - { - "name": "STAWnet: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/SolarEnergy.yaml" - }, - { - "name": "STAWnet: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAWnet/PEMSD7.yaml" - }, - { - "name": "DCRNN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/NYCBike-InFlow.yaml" - }, - { - "name": "DCRNN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD4.yaml" - }, - { - "name": "DCRNN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/METR-LA.yaml" - }, - { - "name": "DCRNN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/AirQuality.yaml" - }, - { - "name": "DCRNN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/NYCBike-OutFlow.yaml" - }, - { - "name": "DCRNN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD8.yaml" - }, - { - "name": "DCRNN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD3.yaml" - }, - { - "name": "DCRNN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/SolarEnergy.yaml" - }, - { - "name": "DCRNN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DCRNN/PEMSD7.yaml" - }, - { - "name": "STAEFormer: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/NYCBike-InFlow.yaml" - }, - { - "name": "STAEFormer: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD4.yaml" - }, - { - "name": "STAEFormer: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/METR-LA.yaml" - }, - { - "name": "STAEFormer: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/AirQuality.yaml" - }, - { - "name": "STAEFormer: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/NYCBike-OutFlow.yaml" - }, - { - "name": "STAEFormer: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD8.yaml" - }, - { - "name": "STAEFormer: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD3.yaml" - }, - { - "name": "STAEFormer: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/SolarEnergy.yaml" - }, - { - "name": "STAEFormer: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STAEFormer/PEMSD7.yaml" - }, - { - "name": "STGODE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/NYCBike-InFlow.yaml" - }, - { - "name": "STGODE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD4.yaml" - }, - { - "name": "STGODE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/METR-LA.yaml" - }, - { - "name": "STGODE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/AirQuality.yaml" - }, - { - "name": "STGODE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGODE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD8.yaml" - }, - { - "name": "STGODE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD3.yaml" - }, - { - "name": "STGODE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/SolarEnergy.yaml" - }, - { - "name": "STGODE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGODE/PEMSD7.yaml" - }, - { - "name": "STGNCDE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/NYCBike-InFlow.yaml" - }, - { - "name": "STGNCDE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD4.yaml" - }, - { - "name": "STGNCDE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/METR-LA.yaml" - }, - { - "name": "STGNCDE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/AirQuality.yaml" - }, - { - "name": "STGNCDE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGNCDE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD8.yaml" - }, - { - "name": "STGNCDE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD3.yaml" - }, - { - "name": "STGNCDE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/SolarEnergy.yaml" - }, - { - "name": "STGNCDE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNCDE/PEMSD7.yaml" - }, - { - "name": "ASTRA: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-InFlow.yaml" - }, - { - "name": "ASTRA: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ASTRA/NYCBike-OutFlow.yaml" - }, - { - "name": "ST_SSL: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/NYCBike-InFlow.yaml" - }, - { - "name": "ST_SSL: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD4.yaml" - }, - { - "name": "ST_SSL: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/METR-LA.yaml" - }, - { - "name": "ST_SSL: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/AirQuality.yaml" - }, - { - "name": "ST_SSL: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/NYCBike-OutFlow.yaml" - }, - { - "name": "ST_SSL: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD8.yaml" - }, - { - "name": "ST_SSL: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD3.yaml" - }, - { - "name": "ST_SSL: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/SolarEnergy.yaml" - }, - { - "name": "ST_SSL: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ST_SSL/PEMSD7.yaml" - }, - { - "name": "TCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/NYCBike-InFlow.yaml" - }, - { - "name": "TCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD4.yaml" - }, - { - "name": "TCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/METR-LA.yaml" - }, - { - "name": "TCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/AirQuality.yaml" - }, - { - "name": "TCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/NYCBike-OutFlow.yaml" - }, - { - "name": "TCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD8.yaml" - }, - { - "name": "TCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD3.yaml" - }, - { - "name": "TCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/SolarEnergy.yaml" - }, - { - "name": "TCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/TCN/PEMSD7.yaml" - }, - { - "name": "EXP: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/NYCBike-InFlow.yaml" - }, - { - "name": "EXP: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD4.yaml" - }, - { - "name": "EXP: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/METR-LA.yaml" - }, - { - "name": "EXP: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/AirQuality.yaml" - }, - { - "name": "EXP: SD", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/SD.yaml" - }, - { - "name": "EXP: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/NYCBike-OutFlow.yaml" - }, - { - "name": "EXP: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD8.yaml" - }, - { - "name": "EXP: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD3.yaml" - }, - { - "name": "EXP: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/SolarEnergy.yaml" - }, - { - "name": "EXP: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/EXP/PEMSD7.yaml" - }, - { - "name": "DDGCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/NYCBike-InFlow.yaml" - }, - { - "name": "DDGCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD4.yaml" - }, - { - "name": "DDGCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/AirQuality.yaml" - }, - { - "name": "DDGCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "DDGCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD8.yaml" - }, - { - "name": "DDGCRN: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7(L).yaml" - }, - { - "name": "DDGCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD3.yaml" - }, - { - "name": "DDGCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/SolarEnergy.yaml" - }, - { - "name": "DDGCRN: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/Hainan.yaml" - }, - { - "name": "DDGCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7.yaml" - }, - { - "name": "DDGCRN: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DDGCRN/PEMSD7(M).yaml" - }, - { - "name": "DSANET: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/NYCBike-InFlow.yaml" - }, - { - "name": "DSANET: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD4.yaml" - }, - { - "name": "DSANET: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/METR-LA.yaml" - }, - { - "name": "DSANET: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/AirQuality.yaml" - }, - { - "name": "DSANET: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/NYCBike-OutFlow.yaml" - }, - { - "name": "DSANET: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD8.yaml" - }, - { - "name": "DSANET: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD3.yaml" - }, - { - "name": "DSANET: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/SolarEnergy.yaml" - }, - { - "name": "DSANET: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/DSANET/PEMSD7.yaml" - }, - { - "name": "STFGNN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/NYCBike-InFlow.yaml" - }, - { - "name": "STFGNN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD4.yaml" - }, - { - "name": "STFGNN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/METR-LA.yaml" - }, - { - "name": "STFGNN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/AirQuality.yaml" - }, - { - "name": "STFGNN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/NYCBike-OutFlow.yaml" - }, - { - "name": "STFGNN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD8.yaml" - }, - { - "name": "STFGNN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD3.yaml" - }, - { - "name": "STFGNN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/SolarEnergy.yaml" - }, - { - "name": "STFGNN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STFGNN/PEMSD7.yaml" - }, - { - "name": "AGCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/NYCBike-InFlow.yaml" - }, - { - "name": "AGCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD4.yaml" - }, - { - "name": "AGCRN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/METR-LA.yaml" - }, - { - "name": "AGCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/AirQuality.yaml" - }, - { - "name": "AGCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "AGCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD8.yaml" - }, - { - "name": "AGCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD3.yaml" - }, - { - "name": "AGCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/SolarEnergy.yaml" - }, - { - "name": "AGCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/AGCRN/PEMSD7.yaml" - }, - { - "name": "STGNRDE: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/NYCBike-InFlow.yaml" - }, - { - "name": "STGNRDE: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD4.yaml" - }, - { - "name": "STGNRDE: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/METR-LA.yaml" - }, - { - "name": "STGNRDE: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/AirQuality.yaml" - }, - { - "name": "STGNRDE: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/NYCBike-OutFlow.yaml" - }, - { - "name": "STGNRDE: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD8.yaml" - }, - { - "name": "STGNRDE: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD3.yaml" - }, - { - "name": "STGNRDE: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/SolarEnergy.yaml" - }, - { - "name": "STGNRDE: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGNRDE/PEMSD7.yaml" - }, - { - "name": "REPST: PEMS-BAY_paper", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/PEMS-BAY_paper.yaml" - }, - { - "name": "REPST: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-InFlow.yaml" - }, - { - "name": "REPST: BeijingAirQuality(Deprecated)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/BeijingAirQuality(Deprecated).yaml" - }, - { - "name": "REPST: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/REPST/NYCBike-OutFlow.yaml" - }, - { - "name": "STIDGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STIDGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD4.yaml" - }, - { - "name": "STIDGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/METR-LA.yaml" - }, - { - "name": "STIDGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/AirQuality.yaml" - }, - { - "name": "STIDGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STIDGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD8.yaml" - }, - { - "name": "STIDGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD3.yaml" - }, - { - "name": "STIDGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/SolarEnergy.yaml" - }, - { - "name": "STIDGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STIDGCN/PEMSD7.yaml" - }, - { - "name": "PDG2SEQ: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/NYCBike-InFlow.yaml" - }, - { - "name": "PDG2SEQ: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD4.yaml" - }, - { - "name": "PDG2SEQ: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/METR-LA.yaml" - }, - { - "name": "PDG2SEQ: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/AirQuality.yaml" - }, - { - "name": "PDG2SEQ: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/NYCBike-OutFlow.yaml" - }, - { - "name": "PDG2SEQ: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD8.yaml" - }, - { - "name": "PDG2SEQ: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD3.yaml" - }, - { - "name": "PDG2SEQ: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/SolarEnergy.yaml" - }, - { - "name": "PDG2SEQ: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/PDG2SEQ/PEMSD7.yaml" - }, - { - "name": "NLT: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/NYCBike-InFlow.yaml" - }, - { - "name": "NLT: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD4.yaml" - }, - { - "name": "NLT: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/METR-LA.yaml" - }, - { - "name": "NLT: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/AirQuality.yaml" - }, - { - "name": "NLT: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/NYCBike-OutFlow.yaml" - }, - { - "name": "NLT: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD8.yaml" - }, - { - "name": "NLT: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD3.yaml" - }, - { - "name": "NLT: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/SolarEnergy.yaml" - }, - { - "name": "NLT: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/NLT/PEMSD7.yaml" - }, - { - "name": "ARIMA: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/NYCBike-InFlow.yaml" - }, - { - "name": "ARIMA: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD4.yaml" - }, - { - "name": "ARIMA: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/METR-LA.yaml" - }, - { - "name": "ARIMA: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/AirQuality.yaml" - }, - { - "name": "ARIMA: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/NYCBike-OutFlow.yaml" - }, - { - "name": "ARIMA: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD8.yaml" - }, - { - "name": "ARIMA: PEMSD7(L)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7(L).yaml" - }, - { - "name": "ARIMA: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD3.yaml" - }, - { - "name": "ARIMA: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/SolarEnergy.yaml" - }, - { - "name": "ARIMA: Hainan", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/Hainan.yaml" - }, - { - "name": "ARIMA: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7.yaml" - }, - { - "name": "ARIMA: PEMSD7(M)", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/ARIMA/PEMSD7(M).yaml" - }, - { - "name": "STMLP: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/NYCBike-InFlow.yaml" - }, - { - "name": "STMLP: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD4.yaml" - }, - { - "name": "STMLP: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/METR-LA.yaml" - }, - { - "name": "STMLP: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/AirQuality.yaml" - }, - { - "name": "STMLP: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/NYCBike-OutFlow.yaml" - }, - { - "name": "STMLP: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD8.yaml" - }, - { - "name": "STMLP: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD3.yaml" - }, - { - "name": "STMLP: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/SolarEnergy.yaml" - }, - { - "name": "STMLP: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STMLP/PEMSD7.yaml" - }, - { - "name": "MegaCRN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/NYCBike-InFlow.yaml" - }, - { - "name": "MegaCRN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD4.yaml" - }, - { - "name": "MegaCRN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/METR-LA.yaml" - }, - { - "name": "MegaCRN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/AirQuality.yaml" - }, - { - "name": "MegaCRN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/NYCBike-OutFlow.yaml" - }, - { - "name": "MegaCRN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD8.yaml" - }, - { - "name": "MegaCRN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD3.yaml" - }, - { - "name": "MegaCRN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/SolarEnergy.yaml" - }, - { - "name": "MegaCRN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/MegaCRN/PEMSD7.yaml" - }, - { - "name": "GWN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/NYCBike-InFlow.yaml" - }, - { - "name": "GWN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD4.yaml" - }, - { - "name": "GWN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/METR-LA.yaml" - }, - { - "name": "GWN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/AirQuality.yaml" - }, - { - "name": "GWN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/NYCBike-OutFlow.yaml" - }, - { - "name": "GWN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD8.yaml" - }, - { - "name": "GWN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD3.yaml" - }, - { - "name": "GWN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/SolarEnergy.yaml" - }, - { - "name": "GWN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/GWN/PEMSD7.yaml" - }, - { - "name": "STGCN: NYCBike-InFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/NYCBike-InFlow.yaml" - }, - { - "name": "STGCN: PEMSD4", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD4.yaml" - }, - { - "name": "STGCN: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/METR-LA.yaml" - }, - { - "name": "STGCN: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/AirQuality.yaml" - }, - { - "name": "STGCN: NYCBike-OutFlow", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/NYCBike-OutFlow.yaml" - }, - { - "name": "STGCN: PEMSD8", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD8.yaml" - }, - { - "name": "STGCN: PEMSD3", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD3.yaml" - }, - { - "name": "STGCN: SolarEnergy", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/SolarEnergy.yaml" - }, - { - "name": "STGCN: PEMSD7", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/STGCN/PEMSD7.yaml" - }, - { - "name": "iTransformer: METR-LA", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/iTransformer/METR-LA.yaml" - }, - { - "name": "iTransformer: AirQuality", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/iTransformer/AirQuality.yaml" - }, - { - "name": "HI: PEMS-BAY", - "type": "debugpy", - "request": "launch", - "program": "run.py", - "console": "integratedTerminal", - "args": "--config ./config/HI/PEMS-BAY.yaml" - }, + "justMyCode": false + } ] } \ No newline at end of file diff --git a/config/MTGNN/AirQuality.yaml b/config/MTGNN/AirQuality.yaml new file mode 100644 index 0000000..9846895 --- /dev/null +++ b/config/MTGNN/AirQuality.yaml @@ -0,0 +1,64 @@ +basic: + dataset: AirQuality + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 35 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 6 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 6 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 6 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/BJTaxi-Inflow.yaml b/config/MTGNN/BJTaxi-Inflow.yaml new file mode 100644 index 0000000..09e453a --- /dev/null +++ b/config/MTGNN/BJTaxi-Inflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/BJTaxi-Outflow.yaml b/config/MTGNN/BJTaxi-Outflow.yaml new file mode 100644 index 0000000..1b62a4e --- /dev/null +++ b/config/MTGNN/BJTaxi-Outflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: BJTaxi-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 1024 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/METR-LA.yaml b/config/MTGNN/METR-LA.yaml new file mode 100644 index 0000000..2518638 --- /dev/null +++ b/config/MTGNN/METR-LA.yaml @@ -0,0 +1,64 @@ +basic: + dataset: METR-LA + device: cuda:1 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 207 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/NYCBike-Inflow.yaml b/config/MTGNN/NYCBike-Inflow.yaml new file mode 100644 index 0000000..95ae41b --- /dev/null +++ b/config/MTGNN/NYCBike-Inflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-InFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/NYCBike-Outflow.yaml b/config/MTGNN/NYCBike-Outflow.yaml new file mode 100644 index 0000000..b1646ea --- /dev/null +++ b/config/MTGNN/NYCBike-Outflow.yaml @@ -0,0 +1,64 @@ +basic: + dataset: NYCBike-OutFlow + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 128 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/PEMS-BAY.yaml b/config/MTGNN/PEMS-BAY.yaml new file mode 100644 index 0000000..7f28aca --- /dev/null +++ b/config/MTGNN/PEMS-BAY.yaml @@ -0,0 +1,64 @@ +basic: + dataset: PEMS-BAY + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 325 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/MTGNN/SolarEnergy.yaml b/config/MTGNN/SolarEnergy.yaml new file mode 100644 index 0000000..2f60b8d --- /dev/null +++ b/config/MTGNN/SolarEnergy.yaml @@ -0,0 +1,64 @@ +basic: + dataset: SolarEnergy + device: cuda:0 + mode: train + model: MTGNN + seed: 2023 + +data: + batch_size: 64 + column_wise: false + days_per_week: 7 + horizon: 24 + input_dim: 6 + lag: 24 + normalizer: std + num_nodes: 137 + steps_per_day: 24 + test_ratio: 0.2 + val_ratio: 0.2 + +model: + gcn_true: True # 是否使用图卷积网络 (bool) + buildA_true: True # 是否动态构建邻接矩阵 (bool) + subgraph_size: 20 # 子图大小 (int) + num_nodes: 137 # 节点数量 (int) + node_dim: 40 # 节点嵌入维度 (int) + dilation_exponential: 1 # 膨胀卷积指数 (int) + conv_channels: 32 # 卷积通道数 (int) + residual_channels: 32 # 残差通道数 (int) + skip_channels: 64 # 跳跃连接通道数 (int) + end_channels: 128 # 输出层通道数 (int) + seq_len: 24 # 输入序列长度 (int) + in_dim: 1 # 输入特征维度 (int) + out_len: 24 # 输出序列长度 (int) + out_dim: 1 # 输出预测维度 (int) + layers: 3 # 模型层数 (int) + propalpha: 0.05 # 图传播参数alpha (float) + tanhalpha: 3 # tanh激活参数alpha (float) + layer_norm_affline: True # 层归一化是否使用affine变换 (bool) + gcn_depth: 2 # 图卷积深度 (int) + dropout: 0.3 # dropout率 (float) + predefined_A: null # 预定义邻接矩阵 (optional, None) + static_feat: null # 静态特征 (optional, None) + +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_step: 1000 + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + plot: false + real_value: true + weight_decay: 0 \ No newline at end of file diff --git a/config/PatchTST/AirQuality.yaml b/config/PatchTST/AirQuality.yaml index a3e6418..3cdf977 100644 --- a/config/PatchTST/AirQuality.yaml +++ b/config/PatchTST/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 6 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/BJTaxi-Inflow.yaml b/config/PatchTST/BJTaxi-Inflow.yaml index 9bd66d9..576dbd6 100644 --- a/config/PatchTST/BJTaxi-Inflow.yaml +++ b/config/PatchTST/BJTaxi-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/BJTaxi-Outflow.yaml b/config/PatchTST/BJTaxi-Outflow.yaml index 2382695..773ba26 100644 --- a/config/PatchTST/BJTaxi-Outflow.yaml +++ b/config/PatchTST/BJTaxi-Outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/METR-LA.yaml b/config/PatchTST/METR-LA.yaml index d076d35..6b9461a 100644 --- a/config/PatchTST/METR-LA.yaml +++ b/config/PatchTST/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:1 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/NYCBike-Inflow.yaml b/config/PatchTST/NYCBike-Inflow.yaml index 2c3026c..408995c 100644 --- a/config/PatchTST/NYCBike-Inflow.yaml +++ b/config/PatchTST/NYCBike-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/NYCBike-Outflow.yaml b/config/PatchTST/NYCBike-Outflow.yaml index 16eee20..c50f4a1 100644 --- a/config/PatchTST/NYCBike-Outflow.yaml +++ b/config/PatchTST/NYCBike-Outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -23,6 +23,7 @@ model: seq_len: 24 pred_len: 24 patch_len: 6 + enc_in: 1 stride: 8 d_model: 128 d_ff: 2048 diff --git a/config/PatchTST/PEMS-BAY.yaml b/config/PatchTST/PEMS-BAY.yaml index 6186db3..e798294 100644 --- a/config/PatchTST/PEMS-BAY.yaml +++ b/config/PatchTST/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -24,6 +24,7 @@ model: pred_len: 24 d_model: 128 patch_len: 6 + enc_in: 1 stride: 8 d_ff: 2048 dropout: 0.1 diff --git a/config/PatchTST/SolarEnergy.yaml b/config/PatchTST/SolarEnergy.yaml index 28b85b9..b1de602 100644 --- a/config/PatchTST/SolarEnergy.yaml +++ b/config/PatchTST/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: iTransformer + model: PatchTST seed: 2023 data: @@ -24,6 +24,7 @@ model: pred_len: 24 d_model: 128 patch_len: 6 + enc_in: 6 stride: 8 d_ff: 2048 dropout: 0.1 diff --git a/model/MTGNN/MTGNN.py b/model/MTGNN/MTGNN.py index 483a184..43d9b31 100644 --- a/model/MTGNN/MTGNN.py +++ b/model/MTGNN/MTGNN.py @@ -3,91 +3,109 @@ from model.MTGNN.layer import * class gtnet(nn.Module): - def __init__(self, gcn_true, buildA_true, gcn_depth, num_nodes, device, predefined_A=None, static_feat=None, dropout=0.3, subgraph_size=20, node_dim=40, dilation_exponential=1, conv_channels=32, residual_channels=32, skip_channels=64, end_channels=128, seq_length=12, in_dim=2, out_dim=12, layers=3, propalpha=0.05, tanhalpha=3, layer_norm_affline=True): + def __init__(self, configs): super(gtnet, self).__init__() - self.gcn_true = gcn_true - self.buildA_true = buildA_true - self.num_nodes = num_nodes - self.dropout = dropout - self.predefined_A = predefined_A - self.filter_convs = nn.ModuleList() - self.gate_convs = nn.ModuleList() - self.residual_convs = nn.ModuleList() - self.skip_convs = nn.ModuleList() - self.gconv1 = nn.ModuleList() - self.gconv2 = nn.ModuleList() - self.norm = nn.ModuleList() - self.start_conv = nn.Conv2d(in_channels=in_dim, - out_channels=residual_channels, + self.gcn_true = configs['gcn_true'] # 是否使用图卷积网络 + self.buildA_true = configs['buildA_true'] # 是否动态构建邻接矩阵 + self.num_nodes = configs['num_nodes'] # 节点数量 + self.device = configs['device'] # 设备(CPU/GPU) + self.dropout = configs['dropout'] # dropout率 + self.predefined_A = configs.get('predefined_A', None) # 预定义邻接矩阵 + self.static_feat = configs.get('static_feat', None) # 静态特征 + self.subgraph_size = configs['subgraph_size'] # 子图大小 + self.node_dim = configs['node_dim'] # 节点嵌入维度 + self.dilation_exponential = configs['dilation_exponential'] # 膨胀卷积指数 + self.conv_channels = configs['conv_channels'] # 卷积通道数 + self.residual_channels = configs['residual_channels'] # 残差通道数 + self.skip_channels = configs['skip_channels'] # 跳跃连接通道数 + self.end_channels = configs['end_channels'] # 输出层通道数 + self.seq_length = configs['seq_len'] # 输入序列长度 + self.in_dim = configs['in_dim'] # 输入特征维度 + self.out_len = configs['out_len'] # 输出序列长度 + self.out_dim = configs['out_dim'] # 输出预测维度 + self.layers = configs['layers'] # 模型层数 + self.propalpha = configs['propalpha'] # 图传播参数alpha + self.tanhalpha = configs['tanhalpha'] # tanh激活参数alpha + self.layer_norm_affline = configs['layer_norm_affline'] # 层归一化是否使用affine变换 + self.gcn_depth = configs['gcn_depth'] # 图卷积深度 + self.filter_convs = nn.ModuleList() # 卷积滤波器列表 + self.gate_convs = nn.ModuleList() # 门控卷积列表 + self.residual_convs = nn.ModuleList() # 残差卷积列表 + self.skip_convs = nn.ModuleList() # 跳跃连接卷积列表 + self.gconv1 = nn.ModuleList() # 第一层图卷积列表 + self.gconv2 = nn.ModuleList() # 第二层图卷积列表 + self.norm = nn.ModuleList() # 归一化层列表 + self.start_conv = nn.Conv2d(in_channels=self.in_dim, + out_channels=self.residual_channels, kernel_size=(1, 1)) - self.gc = graph_constructor(num_nodes, subgraph_size, node_dim, device, alpha=tanhalpha, static_feat=static_feat) + self.gc = graph_constructor(self.num_nodes, self.subgraph_size, self.node_dim, self.device, alpha=self.tanhalpha, static_feat=self.static_feat) - self.seq_length = seq_length kernel_size = 7 - if dilation_exponential>1: - self.receptive_field = int(1+(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + if self.dilation_exponential>1: + self.receptive_field = int(1+(kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1)) else: - self.receptive_field = layers*(kernel_size-1) + 1 + self.receptive_field = self.layers*(kernel_size-1) + 1 for i in range(1): - if dilation_exponential>1: - rf_size_i = int(1 + i*(kernel_size-1)*(dilation_exponential**layers-1)/(dilation_exponential-1)) + if self.dilation_exponential>1: + rf_size_i = int(1 + i*(kernel_size-1)*(self.dilation_exponential**self.layers-1)/(self.dilation_exponential-1)) else: - rf_size_i = i*layers*(kernel_size-1)+1 + rf_size_i = i*self.layers*(kernel_size-1)+1 new_dilation = 1 - for j in range(1,layers+1): - if dilation_exponential > 1: - rf_size_j = int(rf_size_i + (kernel_size-1)*(dilation_exponential**j-1)/(dilation_exponential-1)) + for j in range(1,self.layers+1): + if self.dilation_exponential > 1: + rf_size_j = int(rf_size_i + (kernel_size-1)*(self.dilation_exponential**j-1)/(self.dilation_exponential-1)) else: rf_size_j = rf_size_i+j*(kernel_size-1) - self.filter_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) - self.gate_convs.append(dilated_inception(residual_channels, conv_channels, dilation_factor=new_dilation)) - self.residual_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=residual_channels, + self.filter_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.gate_convs.append(dilated_inception(self.residual_channels, self.conv_channels, dilation_factor=new_dilation)) + self.residual_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.residual_channels, kernel_size=(1, 1))) if self.seq_length>self.receptive_field: - self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=skip_channels, + self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.skip_channels, kernel_size=(1, self.seq_length-rf_size_j+1))) else: - self.skip_convs.append(nn.Conv2d(in_channels=conv_channels, - out_channels=skip_channels, + self.skip_convs.append(nn.Conv2d(in_channels=self.conv_channels, + out_channels=self.skip_channels, kernel_size=(1, self.receptive_field-rf_size_j+1))) if self.gcn_true: - self.gconv1.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) - self.gconv2.append(mixprop(conv_channels, residual_channels, gcn_depth, dropout, propalpha)) + self.gconv1.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha)) + self.gconv2.append(mixprop(self.conv_channels, self.residual_channels, self.gcn_depth, self.dropout, self.propalpha)) if self.seq_length>self.receptive_field: - self.norm.append(LayerNorm((residual_channels, num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.seq_length - rf_size_j + 1),elementwise_affine=self.layer_norm_affline)) else: - self.norm.append(LayerNorm((residual_channels, num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=layer_norm_affline)) + self.norm.append(LayerNorm((self.residual_channels, self.num_nodes, self.receptive_field - rf_size_j + 1),elementwise_affine=self.layer_norm_affline)) - new_dilation *= dilation_exponential + new_dilation *= self.dilation_exponential - self.layers = layers - self.end_conv_1 = nn.Conv2d(in_channels=skip_channels, - out_channels=end_channels, + self.end_conv_1 = nn.Conv2d(in_channels=self.skip_channels, + out_channels=self.end_channels, kernel_size=(1,1), bias=True) - self.end_conv_2 = nn.Conv2d(in_channels=end_channels, - out_channels=out_dim, + self.end_conv_2 = nn.Conv2d(in_channels=self.end_channels, + out_channels=self.out_len * self.out_dim, kernel_size=(1,1), bias=True) if self.seq_length > self.receptive_field: - self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.seq_length), bias=True) - self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.seq_length), bias=True) + self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, self.seq_length-self.receptive_field+1), bias=True) + else: - self.skip0 = nn.Conv2d(in_channels=in_dim, out_channels=skip_channels, kernel_size=(1, self.receptive_field), bias=True) - self.skipE = nn.Conv2d(in_channels=residual_channels, out_channels=skip_channels, kernel_size=(1, 1), bias=True) + self.skip0 = nn.Conv2d(in_channels=self.in_dim, out_channels=self.skip_channels, kernel_size=(1, self.receptive_field), bias=True) + self.skipE = nn.Conv2d(in_channels=self.residual_channels, out_channels=self.skip_channels, kernel_size=(1, 1), bias=True) - - self.idx = torch.arange(self.num_nodes).to(device) + self.idx = torch.arange(self.num_nodes).to(self.device) def forward(self, input, idx=None): + input = input[..., :-2] # 去掉周期嵌入 + input = input.transpose(1, 3) seq_len = input.size(3) assert seq_len==self.seq_length, 'input sequence length not equal to preset sequence length' @@ -130,5 +148,8 @@ class gtnet(nn.Module): skip = self.skipE(x) + skip x = F.relu(skip) x = F.relu(self.end_conv_1(x)) - x = self.end_conv_2(x) + x = self.end_conv_2(x) # [b, t*c, n, 1] + # [b, t*c, n, 1] -> [b,t,c,n] -> [b, t, n, c] + x = x.reshape(x.size(0), self.out_len, self.out_dim, self.num_nodes) + x = x.permute(0, 1, 3, 2) return x \ No newline at end of file diff --git a/model/PatchTST/PatchTST.py b/model/PatchTST/PatchTST.py index 3112030..4645c28 100644 --- a/model/PatchTST/PatchTST.py +++ b/model/PatchTST/PatchTST.py @@ -62,14 +62,14 @@ class Model(nn.Module): activation=configs['activation'] ) for l in range(configs['e_layers']) ], - norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs.d_model), Transpose(1,2)) + norm_layer=nn.Sequential(Transpose(1,2), nn.BatchNorm1d(configs['d_model']), Transpose(1,2)) ) # Prediction Head - self.head_nf = configs.d_model * \ - int((configs.seq_len - self.patch_len) / self.stride + 2) - self.head = FlattenHead(configs.enc_in, self.head_nf, configs.pred_len, - head_dropout=configs.dropout) + self.head_nf = configs['d_model'] * \ + int((configs['seq_len'] - self.patch_len) / self.stride + 2) + self.head = FlattenHead(configs['enc_in'], self.head_nf, configs['pred_len'], + head_dropout=configs['dropout']) def forecast(self, x_enc): # Normalization from Non-stationary Transformer diff --git a/model/PatchTST/layers/Embed.py b/model/PatchTST/layers/Embed.py index 94896e0..d38d093 100644 --- a/model/PatchTST/layers/Embed.py +++ b/model/PatchTST/layers/Embed.py @@ -1,5 +1,26 @@ import torch import torch.nn as nn +import math + +class PositionalEmbedding(nn.Module): + def __init__(self, d_model, max_len=5000): + super(PositionalEmbedding, self).__init__() + # Compute the positional encodings once in log space. + pe = torch.zeros(max_len, d_model).float() + pe.require_grad = False + + position = torch.arange(0, max_len).float().unsqueeze(1) + div_term = (torch.arange(0, d_model, 2).float() + * -(math.log(10000.0) / d_model)).exp() + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + self.register_buffer('pe', pe) + + def forward(self, x): + return self.pe[:, :x.size(1)] class PatchEmbedding(nn.Module): def __init__(self, d_model, patch_len, stride, padding, dropout): diff --git a/model/model_selector.py b/model/model_selector.py index 09b7fdc..5621037 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -30,6 +30,7 @@ from model.ASTRA.astrav3 import ASTRA as ASTRAv3 from model.iTransformer.iTransformer import iTransformer from model.HI.HI import HI from model.PatchTST.PatchTST import Model as PatchTST +from model.MTGNN.MTGNN import gtnet as MTGNN @@ -99,3 +100,5 @@ def model_selector(config): return HI(model_config) case "PatchTST": return PatchTST(model_config) + case "MTGNN": + return MTGNN(model_config) diff --git a/train.py b/train.py index dad4609..7bd72ad 100644 --- a/train.py +++ b/train.py @@ -1,10 +1,27 @@ import yaml import torch +import os import utils.initializer as init from dataloader.loader_selector import get_dataloader from trainer.trainer_selector import select_trainer +def read_config(config_path): + with open(config_path, "r") as file: + config = yaml.safe_load(file) + + # 全局配置 + device = "cuda:0" # 指定设备 + seed = 2023 # 随机种子 + epochs = 100 + + # 拷贝项 + config["basic"]["device"] = device + config["model"]["device"] = device + config["basic"]["seed"] = seed + config["train"]["epochs"] = epochs + return config + def run(config): init.init_seed(config["basic"]["seed"]) model = init.init_model(config) @@ -45,22 +62,26 @@ def run(config): if __name__ == "__main__": # 指定模型 - model_list = ["PatchTST"] + model_list = ["MTGNN"] # 指定数据集 - dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] + dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-Inflow", "BJTaxi-Outflow", "NYCBike-Inflow", "NYCBike-Outflow"] # dataset_list = ["AirQuality"] - device = "cuda:0" # 指定设备 - seed = 2023 # 随机种子 - epochs = 1 + + # 我的调试开关,不做测试就填 str(False) + os.environ["TRY"] = str(False) + for model in model_list: for dataset in dataset_list: config_path = f"./config/{model}/{dataset}.yaml" - with open(config_path, "r") as file: - config = yaml.safe_load(file) - config["basic"]["device"] = device - config["basic"]["seed"] = seed - config["train"]["epochs"] = epochs - print(f"\nRunning {model} on {dataset} with seed {seed} on {device}") - print(f"config: {config}") - run(config) + # 可去这个函数里面调整统一的config项,⚠️注意调设备,epochs + config = read_config(config_path) + print(f"\nRunning {model} on {dataset}") + # print(f"config: {config}") + if os.environ.get("TRY") == "True": + try: + run(config) + except Exception as e: + pass + else: + run(config) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 80a6672..3372873 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -71,6 +71,10 @@ class Trainer: label = target[..., : self.args["output_dim"]] # 计算loss和反归一化loss output = self.model(data) + # 我的调试开关 + if os.environ.get("TRY") == "True": + print(f"[{'✅' if output.shape == label.shape else '❌'}]: output: {output.shape}, label: {label.shape}") + assert False loss = self.loss(output, label) d_output = self.scaler.inverse_transform(output) d_label = self.scaler.inverse_transform(label)