89 lines
3.0 KiB
Python
Executable File
89 lines
3.0 KiB
Python
Executable File
from model.DDGCRN.DDGCRN import DDGCRN
|
|
from model.TWDGCN.TWDGCN import TWDGCN
|
|
from model.AGCRN.AGCRN import AGCRN
|
|
from model.NLT.HierAttnLstm import HierAttnLstm
|
|
from model.STGNCDE.Make_model import make_model
|
|
from model.DSANET.DSANET import DSANet
|
|
from model.STGCN.models import STGCNChebGraphConv
|
|
from model.DCRNN.dcrnn_model import DCRNNModel
|
|
from model.ARIMA.ARIMA import ARIMA
|
|
from model.TCN.TCN import TemporalConvNet
|
|
from model.GWN.GraphWaveNet import gwnet
|
|
from model.STFGNN.STFGNN import STFGNN
|
|
from model.STSGCN.STSGCN import STSGCN
|
|
from model.STGODE.STGODE import ODEGCN
|
|
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
|
|
from model.STMLP.STMLP import STMLP
|
|
from model.STIDGCN.STIDGCN import STIDGCN
|
|
from model.STID.STID import STID
|
|
from model.STAEFormer.STAEFormer import STAEformer
|
|
from model.EXP.EXP32 import EXP as EXP
|
|
from model.MegaCRN.MegaCRNModel import MegaCRNModel
|
|
from model.ST_SSL.ST_SSL import STSSLModel
|
|
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
|
from model.STAWnet.STAWnet import STAWnet
|
|
from model.REPST.repst import repst as REPST
|
|
from model.AEPSA.aepsa import AEPSA as AEPSA
|
|
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2
|
|
|
|
|
|
|
|
def model_selector(config):
|
|
model_name = config["basic"]["model"]
|
|
model_config = config["model"]
|
|
match model_name:
|
|
case "DDGCRN":
|
|
return DDGCRN(model_config)
|
|
case "TWDGCN":
|
|
return TWDGCN(model_config)
|
|
case "AGCRN":
|
|
return AGCRN(model_config)
|
|
case "NLT":
|
|
return HierAttnLstm(model_config)
|
|
case "STGNCDE":
|
|
return make_model(model_config)
|
|
case "DSANET":
|
|
return DSANet(model_config)
|
|
case "STGCN":
|
|
return STGCNChebGraphConv(model_config)
|
|
case "DCRNN":
|
|
return DCRNNModel(model_config)
|
|
case "ARIMA":
|
|
return ARIMA(model_config)
|
|
case "TCN":
|
|
return TemporalConvNet(model_config)
|
|
case "GWN":
|
|
return gwnet(model_config)
|
|
case "STFGNN":
|
|
return STFGNN(model_config)
|
|
case "STSGCN":
|
|
return STSGCN(model_config)
|
|
case "STGODE":
|
|
return ODEGCN(model_config)
|
|
case "PDG2SEQ":
|
|
return PDG2Seq(model_config)
|
|
case "STMLP":
|
|
return STMLP(model_config)
|
|
case "STIDGCN":
|
|
return STIDGCN(model_config)
|
|
case "STID":
|
|
return STID(model_config)
|
|
case "STAEFormer":
|
|
return STAEformer(model_config)
|
|
case "EXP":
|
|
return EXP(model_config)
|
|
case "MegaCRN":
|
|
return MegaCRNModel(model_config)
|
|
case "ST_SSL":
|
|
return STSSLModel(model_config)
|
|
case "STGNRDE":
|
|
return make_nrde_model(model_config)
|
|
case "STAWnet":
|
|
return STAWnet(model_config)
|
|
case "REPST":
|
|
return REPST(model_config)
|
|
case "AEPSA":
|
|
return AEPSA(model_config)
|
|
case "AEPSA_v2":
|
|
return AEPSAv2(model_config)
|