# TrafficWheel/model/model_selector.py
# (source listing metadata: 39 lines, 1.5 KiB, Python)

from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
from model.PDG2SEQ.PDG2Seq import PDG2Seq
from model.STID.STID import STID
from model.EXP.EXP26 import EXP as EXP
def model_selector(model):
match model['type']:
case 'DDGCRN': return DDGCRN(model)
case 'TWDGCN': return TWDGCN(model)
case 'AGCRN': return AGCRN(model)
case 'NLT': return HierAttnLstm(model)
case 'STGNCDE': return make_model(model)
case 'DSANET': return DSANet(model)
case 'STGCN': return STGCNChebGraphConv(model)
case 'DCRNN': return DCRNNModel(model)
case 'ARIMA': return ARIMA(model)
case 'TCN': return TemporalConvNet(model)
case 'GWN': return gwnet(model)
case 'STFGNN': return STFGNN(model)
case 'STSGCN': return STSGCN(model)
case 'STGODE': return ODEGCN(model)
case 'PDG2SEQ': return PDG2Seq(model)
case 'STID': return STID(model)
case 'EXP': return EXP(model)