55 lines
2.4 KiB
Python
Executable File
55 lines
2.4 KiB
Python
Executable File
from model.DDGCRN.DDGCRN import DDGCRN
|
|
from model.TWDGCN.TWDGCN import TWDGCN
|
|
from model.AGCRN.AGCRN import AGCRN
|
|
from model.NLT.HierAttnLstm import HierAttnLstm
|
|
from model.STGNCDE.Make_model import make_model
|
|
from model.DSANET.DSANET import DSANet
|
|
from model.STGCN.models import STGCNChebGraphConv
|
|
from model.DCRNN.dcrnn_model import DCRNNModel
|
|
from model.ARIMA.ARIMA import ARIMA
|
|
from model.TCN.TCN import TemporalConvNet
|
|
from model.GWN.GraphWaveNet import gwnet
|
|
from model.STFGNN.STFGNN import STFGNN
|
|
from model.STSGCN.STSGCN import STSGCN
|
|
from model.STGODE.STGODE import ODEGCN
|
|
from model.PDG2SEQ.PDG2Seqb import PDG2Seq
|
|
from model.STMLP.STMLP import STMLP
|
|
from model.STIDGCN.STIDGCN import STIDGCN
|
|
from model.STID.STID import STID
|
|
from model.STAEFormer.STAEFormer import STAEformer
|
|
from model.EXP.EXP32 import EXP as EXP
|
|
from model.MegaCRN.MegaCRNModel import MegaCRNModel
|
|
from model.ST_SSL.ST_SSL import STSSLModel
|
|
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
|
from model.STAWnet.STAWnet import STAWnet
|
|
|
|
def model_selector(config):
|
|
model_name = config["basic"]["model"]
|
|
model_config = config["model"]
|
|
match model_name:
|
|
case 'DDGCRN': return DDGCRN(model_config)
|
|
case 'TWDGCN': return TWDGCN(model_config)
|
|
case 'AGCRN': return AGCRN(model_config)
|
|
case 'NLT': return HierAttnLstm(model_config)
|
|
case 'STGNCDE': return make_model(model_config)
|
|
case 'DSANET': return DSANet(model_config)
|
|
case 'STGCN': return STGCNChebGraphConv(model_config)
|
|
case 'DCRNN': return DCRNNModel(model_config)
|
|
case 'ARIMA': return ARIMA(model_config)
|
|
case 'TCN': return TemporalConvNet(model_config)
|
|
case 'GWN': return gwnet(model_config)
|
|
case 'STFGNN': return STFGNN(model_config)
|
|
case 'STSGCN': return STSGCN(model_config)
|
|
case 'STGODE': return ODEGCN(model_config)
|
|
case 'PDG2SEQ': return PDG2Seq(model_config)
|
|
case 'STMLP': return STMLP(model_config)
|
|
case 'STIDGCN': return STIDGCN(model_config)
|
|
case 'STID': return STID(model_config)
|
|
case 'STAEFormer': return STAEformer(model_config)
|
|
case 'EXP': return EXP(model_config)
|
|
case 'MegaCRN': return MegaCRNModel(model_config)
|
|
case 'ST_SSL': return STSSLModel(model_config)
|
|
case 'STGNRDE': return make_nrde_model(model_config)
|
|
case 'STAWnet': return STAWnet(model_config)
|
|
|