# Project-I/trainer/trainer_selector.py
#
# 14 lines
# 754 B
# Python

from trainer.trainer import Trainer
from trainer.ode_trainer import Trainer as ode_trainer
def select_trainer(config, model, loss, optimizer, train_loader, val_loader,
                   test_loader, scaler, lr_scheduler, kwargs):
    """Build the trainer appropriate for the configured model.

    The model name is read from ``config['basic']['model']``. The model
    ``'STDEN'`` gets the trainer from ``trainer.ode_trainer``; every other
    model gets the default ``trainer.trainer.Trainer``. Both trainers are
    constructed with the same positional arguments.

    Args:
        config: Configuration mapping; must contain ``['basic']['model']``.
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, lr_scheduler: Forwarded unchanged to the trainer constructor.
        kwargs: Accepted for caller compatibility but currently unused
            (it is not forwarded to either trainer).

    Returns:
        The constructed trainer instance.

    Raises:
        KeyError: If ``config`` lacks the ``['basic']['model']`` entry.
    """
    model_name = config['basic']['model']
    # Only STDEN needs the ODE-specific trainer; everything else falls back to
    # the default. This replaces a `match` whose wildcard case made the old
    # `is None` -> NotImplementedError guard unreachable.
    trainer_cls = ode_trainer if model_name == 'STDEN' else Trainer
    return trainer_cls(config, model, loss, optimizer, train_loader,
                       val_loader, test_loader, scaler, lr_scheduler)