from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer


def select_trainer(model, loss, optimizer, train_loader, val_loader,
                   test_loader, scaler, args, lr_scheduler, kwargs):
    """Build the trainer object that corresponds to the configured model type.

    The model type is read from ``args['model']['type']``:

    * ``"STGNCDE"`` -> ``cdeTrainer``; in addition to the shared arguments it
      receives ``kwargs[0]`` and ``None`` as two trailing positional
      arguments.
    * ``'DCRNN'``   -> ``DCRNN_Trainer``.
    * anything else -> the generic ``Trainer``.

    Every trainer receives, in this order: ``model``, ``loss``,
    ``optimizer``, ``train_loader``, ``val_loader``, ``test_loader``,
    ``scaler``, ``args['train']`` and ``lr_scheduler``.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, lr_scheduler: Forwarded unchanged to the trainer constructor.
        args: Mapping that provides ``args['model']['type']`` and
            ``args['train']``.
        kwargs: Indexable object; only ``kwargs[0]`` is read, and only for
            the ``"STGNCDE"`` model type.
            NOTE(review): what ``kwargs[0]`` holds is not visible here --
            confirm against the caller / ``cdeTrainer`` signature.

    Returns:
        The constructed trainer instance.
    """
    model_type = args['model']['type']

    # Positional arguments that every trainer constructor has in common.
    shared = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args['train'],
        lr_scheduler,
    )

    if model_type == "STGNCDE":
        # The CDE trainer takes two extra trailing positional arguments.
        return cdeTrainer(*shared, kwargs[0], None)

    if model_type == 'DCRNN':
        return DCRNN_Trainer(*shared)

    # Fallback for every other model type.
    return Trainer(*shared)