17 lines
478 B
Python
17 lines
478 B
Python
from models.STDEN.stden_model import STDENModel
|
|
from models.STGODE.STGODE import ODEGCN
|
|
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2
|
|
|
|
|
|
def model_selector(config):
|
|
model_name = config['basic']['model']
|
|
model = None
|
|
match model_name:
|
|
case 'STDEN':
|
|
model = STDENModel(config)
|
|
case 'STGODE':
|
|
model = ODEGCN(config)
|
|
case 'STGODE-LLM-GPT2':
|
|
model = ODEGCN_LLM_GPT2(config)
|
|
return model
|