from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seq import PDG2Seq
from model.EXP.EXP import EXP
from model.EXPB.EXP_b import EXPB


# Maps the value of a config's 'type' key to the callable that builds the
# model. Every callable receives the full config mapping as its only argument.
# Most entries are classes; 'STGNCDE' maps to the factory function
# `make_model` rather than a class.
_MODEL_BUILDERS = {
    'DDGCRN': DDGCRN,
    'TWDGCN': TWDGCN,
    'AGCRN': AGCRN,
    'NLT': HierAttnLstm,
    'STGNCDE': make_model,
    'DSANET': DSANet,
    'STGCN': STGCNChebGraphConv,
    'DCRNN': DCRNNModel,
    'ARIMA': ARIMA,
    'TCN': TemporalConvNet,
    'GWN': gwnet,
    'STFGNN': STFGNN,
    'STSGCN': STSGCN,
    'STGODE': ODEGCN,
    'PDG2SEQ': PDG2Seq,
    'EXP': EXP,
    'EXPB': EXPB,
}


def model_selector(model):
    """Build the model named by ``model['type']``.

    Args:
        model: Model configuration mapping. Its ``'type'`` entry selects the
            implementation (one of the keys of ``_MODEL_BUILDERS``); the whole
            mapping is then passed unchanged to that implementation's
            constructor / factory.

    Returns:
        Whatever the selected constructor or factory returns for ``model``.

    Raises:
        KeyError: If ``model`` has no ``'type'`` entry.
        ValueError: If ``model['type']`` is not a supported model type.
            (Previously an unknown type silently returned ``None``.)
    """
    model_type = model['type']
    try:
        builder = _MODEL_BUILDERS[model_type]
    except KeyError:
        # Fail loudly here instead of returning None, which would only
        # surface later as a confusing AttributeError/TypeError downstream.
        supported = ', '.join(sorted(_MODEL_BUILDERS))
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of: {supported}"
        ) from None
    return builder(model)