Project-I/models/model_selector.py

17 lines
478 B
Python

from models.STDEN.stden_model import STDENModel
from models.STGODE.STGODE import ODEGCN
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2
def model_selector(config):
model_name = config['basic']['model']
model = None
match model_name:
case 'STDEN':
model = STDENModel(config)
case 'STGODE':
model = ODEGCN(config)
case 'STGODE-LLM-GPT2':
model = ODEGCN_LLM_GPT2(config)
return model