# TrafficWheel/model/model_selector.py (Python, executable file)

from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
from model.STMLP.STMLP import STMLP
from model.STIDGCN.STIDGCN import STIDGCN
from model.STID.STID import STID
from model.STAEFormer.STAEFormer import STAEformer
from model.EXP.EXP32 import EXP as EXP
from model.MegaCRN.MegaCRNModel import MegaCRNModel
from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST
from model.AEPSA.aepsa import AEPSA as AEPSA
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2
def model_selector(config):
model_name = config["basic"]["model"]
model_config = config["model"]
match model_name:
case "DDGCRN":
return DDGCRN(model_config)
case "TWDGCN":
return TWDGCN(model_config)
case "AGCRN":
return AGCRN(model_config)
case "NLT":
return HierAttnLstm(model_config)
case "STGNCDE":
return make_model(model_config)
case "DSANET":
return DSANet(model_config)
case "STGCN":
return STGCNChebGraphConv(model_config)
case "DCRNN":
return DCRNNModel(model_config)
case "ARIMA":
return ARIMA(model_config)
case "TCN":
return TemporalConvNet(model_config)
case "GWN":
return gwnet(model_config)
case "STFGNN":
return STFGNN(model_config)
case "STSGCN":
return STSGCN(model_config)
case "STGODE":
return ODEGCN(model_config)
case "PDG2SEQ":
return PDG2Seq(model_config)
case "STMLP":
return STMLP(model_config)
case "STIDGCN":
return STIDGCN(model_config)
case "STID":
return STID(model_config)
case "STAEFormer":
return STAEformer(model_config)
case "EXP":
return EXP(model_config)
case "MegaCRN":
return MegaCRNModel(model_config)
case "ST_SSL":
return STSSLModel(model_config)
case "STGNRDE":
return make_nrde_model(model_config)
case "STAWnet":
return STAWnet(model_config)
case "REPST":
return REPST(model_config)
case "AEPSA":
return AEPSA(model_config)
case "AEPSA_v2":
return AEPSAv2(model_config)