from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN

# Maps the value of ``model['type']`` to the callable that builds that model.
# Each callable is invoked with the full ``model`` config as its only argument.
# STGNCDE goes through the ``make_model`` factory rather than a class.
_MODEL_BUILDERS = {
    'DDGCRN': DDGCRN,
    'TWDGCN': TWDGCN,
    'AGCRN': AGCRN,
    'NLT': HierAttnLstm,
    'STGNCDE': make_model,
    'DSANET': DSANet,
    'STGCN': STGCNChebGraphConv,
    'DCRNN': DCRNNModel,
    'ARIMA': ARIMA,
    'TCN': TemporalConvNet,
    'GWN': gwnet,
    'STFGNN': STFGNN,
    'STSGCN': STSGCN,
    'STGODE': ODEGCN,
}


def model_selector(model):
    """Build the model named by ``model['type']``.

    Args:
        model: Model configuration supporting ``model['type']`` lookup.
            The whole object is forwarded unchanged to the selected
            constructor/factory (presumably a dict-like config; confirm
            against callers).

    Returns:
        Whatever the selected constructor/factory returns for ``model``.

    Raises:
        KeyError: If ``model`` has no ``'type'`` entry.
        ValueError: If ``model['type']`` is not one of the supported names.
            Previously an unknown name silently returned ``None``.
    """
    model_type = model['type']
    builder = _MODEL_BUILDERS.get(model_type)
    if builder is None:
        # Fail loudly here instead of returning None and crashing later.
        supported = ', '.join(sorted(_MODEL_BUILDERS))
        raise ValueError(
            f"Unknown model type {model_type!r}; expected one of: {supported}"
        )
    return builder(model)