# TrafficWheel/model/model_selector.py
#
# Dispatch from a model config's 'type' string to the matching model constructor.

from model.DDGCRN.DDGCRN import DDGCRN
from model.TWDGCN.TWDGCN import TWDGCN
from model.AGCRN.AGCRN import AGCRN
from model.NLT.HierAttnLstm import HierAttnLstm
from model.STGNCDE.Make_model import make_model
from model.DSANET.DSANET import DSANet
from model.STGCN.models import STGCNChebGraphConv
from model.DCRNN.dcrnn_model import DCRNNModel
from model.ARIMA.ARIMA import ARIMA
from model.TCN.TCN import TemporalConvNet
from model.GWN.GraphWaveNet import gwnet
from model.STFGNN.STFGNN import STFGNN
from model.STSGCN.STSGCN import STSGCN
from model.STGODE.STGODE import ODEGCN
def model_selector(model):
match model['type']:
case 'DDGCRN': return DDGCRN(model)
case 'TWDGCN': return TWDGCN(model)
case 'AGCRN': return AGCRN(model)
case 'NLT': return HierAttnLstm(model)
case 'STGNCDE': return make_model(model)
case 'DSANET': return DSANet(model)
case 'STGCN': return STGCNChebGraphConv(model)
case 'DCRNN': return DCRNNModel(model)
case 'ARIMA': return ARIMA(model)
case 'TCN': return TemporalConvNet(model)
case 'GWN': return gwnet(model)
case 'STFGNN': return STFGNN(model)
case 'STSGCN': return STSGCN(model)
case 'STGODE': return ODEGCN(model)