from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer


def select_trainer(
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    kwargs,
):
    """Build the trainer that corresponds to the configured model name.

    The model name is read from ``args["basic"]["model"]``. Recognised names
    are ``"STGNCDE"``, ``"STGNRDE"``, ``"DCRNN"``, ``"PDG2SEQ"``, ``"STMLP"``
    and ``"EXP"``; any other value falls back to the generic ``Trainer``.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler: Forwarded positionally, in this order,
            to whichever trainer class is selected.
        kwargs: Only used for ``"STGNCDE"`` / ``"STGNRDE"``, where its
            element at index 0 is forwarded as an extra positional argument.
            NOTE(review): despite the name this is indexed like a sequence,
            not unpacked as keyword arguments -- confirm against callers.

    Returns:
        An instance of the selected trainer class.
    """
    model_name = args["basic"]["model"]

    # Every trainer receives the same nine leading positional arguments.
    shared = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
    )

    # Both CDE-style models use the same trainer, which takes two extra
    # positional arguments: kwargs[0] followed by None.
    if model_name == "STGNCDE" or model_name == "STGNRDE":
        return cdeTrainer(*shared, kwargs[0], None)
    if model_name == "DCRNN":
        return DCRNN_Trainer(*shared)
    if model_name == "PDG2SEQ":
        return PDG2SEQ_Trainer(*shared)
    if model_name == "STMLP":
        return STMLP_Trainer(*shared)
    if model_name == "EXP":
        return EXP_Trainer(*shared)

    # Any unrecognised model name uses the generic trainer.
    return Trainer(*shared)