37 lines
1.4 KiB
Python
37 lines
1.4 KiB
Python
from model.DDGCRN.DDGCRN import DDGCRN
|
|
from model.TWDGCN.TWDGCN import TWDGCN
|
|
from model.AGCRN.AGCRN import AGCRN
|
|
from model.NLT.HierAttnLstm import HierAttnLstm
|
|
from model.STGNCDE.Make_model import make_model
|
|
from model.DSANET.DSANET import DSANet
|
|
from model.STGCN.models import STGCNChebGraphConv
|
|
from model.DCRNN.dcrnn_model import DCRNNModel
|
|
from model.ARIMA.ARIMA import ARIMA
|
|
from model.TCN.TCN import TemporalConvNet
|
|
from model.GWN.GraphWaveNet import gwnet
|
|
from model.STFGNN.STFGNN import STFGNN
|
|
from model.STSGCN.STSGCN import STSGCN
|
|
from model.STGODE.STGODE import ODEGCN
|
|
from model.PDG2SEQ.PDG2Seq import PDG2Seq
|
|
from model.EXP.EXP9 import EXP as EXP
|
|
|
|
def model_selector(model):
|
|
match model['type']:
|
|
case 'DDGCRN': return DDGCRN(model)
|
|
case 'TWDGCN': return TWDGCN(model)
|
|
case 'AGCRN': return AGCRN(model)
|
|
case 'NLT': return HierAttnLstm(model)
|
|
case 'STGNCDE': return make_model(model)
|
|
case 'DSANET': return DSANet(model)
|
|
case 'STGCN': return STGCNChebGraphConv(model)
|
|
case 'DCRNN': return DCRNNModel(model)
|
|
case 'ARIMA': return ARIMA(model)
|
|
case 'TCN': return TemporalConvNet(model)
|
|
case 'GWN': return gwnet(model)
|
|
case 'STFGNN': return STFGNN(model)
|
|
case 'STSGCN': return STSGCN(model)
|
|
case 'STGODE': return ODEGCN(model)
|
|
case 'PDG2SEQ': return PDG2Seq(model)
|
|
case 'EXP': return EXP(model)
|
|
|