|
from models.STDEN.stden_model import STDENModel
|
|
from models.STGODE.STGODE import ODEGCN
|
|
|
|
def model_selector(config):
|
|
model_name = config['basic']['model']
|
|
model = None
|
|
match model_name:
|
|
case 'STDEN':
|
|
model = STDENModel(config)
|
|
case 'STGODE':
|
|
model = ODEGCN(config)
|
|
return model |