from trainer.trainer import Trainer
from trainer.ode_trainer import Trainer as ode_trainer


# Model names that need a dedicated trainer class. Any model name not listed
# here is trained with the generic ``Trainer``.
_MODEL_TRAINERS = {
    'STDEN': ode_trainer,
}


def select_trainer(config, model, loss, optimizer, train_loader, val_loader,
                   test_loader, scaler, lr_scheduler, kwargs):
    """Construct the trainer that matches the configured model.

    The model name is read from ``config['basic']['model']``. Names present in
    ``_MODEL_TRAINERS`` map to their dedicated trainer class (currently
    ``'STDEN'`` -> ``ode_trainer``); every other name falls back to the
    generic ``Trainer``.

    Args:
        config: Nested configuration; must support ``config['basic']['model']``.
            It is also passed through to the trainer constructor.
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, lr_scheduler: Forwarded positionally, in this order, to the
            selected trainer's constructor.
        kwargs: Accepted for call-site compatibility but not used here.

    Returns:
        A new instance of the selected trainer class.

    Raises:
        Any lookup error from ``config['basic']['model']`` (e.g. ``KeyError``
        if ``config`` is a dict without that entry) propagates to the caller.
    """
    model_name = config['basic']['model']
    # The fallback guarantees a class is always chosen, so no "not
    # implemented" branch is needed.
    trainer_cls = _MODEL_TRAINERS.get(model_name, Trainer)
    return trainer_cls(config, model, loss, optimizer, train_loader,
                       val_loader, test_loader, scaler, lr_scheduler)