from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer
from trainer.InformerTrainer import InformerTrainer
from trainer.TSTrainer import Trainer as TSTrainer

# Models handled by the CDE trainer; its constructor takes two extra
# positional arguments after the standard ones (see select_trainer).
_CDE_MODELS = frozenset({"STGNCDE", "STGNRDE"})

# Models handled by the generic time-series trainer.
_TS_MODELS = frozenset({"HI", "PatchTST", "iTransformer", "FPT"})

# Model name -> trainer class for every model whose trainer takes only the
# standard constructor arguments. Names not listed here (and not in
# _CDE_MODELS) fall back to the generic ``Trainer``.
_TRAINER_MAP = {
    "DCRNN": DCRNN_Trainer,
    "PDG2SEQ": PDG2SEQ_Trainer,
    "STMLP": STMLP_Trainer,
    "EXP": EXP_Trainer,
    "Informer": InformerTrainer,
    **{name: TSTrainer for name in _TS_MODELS},
}


def select_trainer(
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    kwargs,
):
    """Construct and return the trainer matching the configured model name.

    The model name is read from ``args["basic"]["model"]``:

    * ``STGNCDE`` / ``STGNRDE`` -> ``cdeTrainer`` (receives two extra
      positional arguments, see ``kwargs`` below);
    * ``HI`` / ``PatchTST`` / ``iTransformer`` / ``FPT`` -> ``TSTrainer``;
    * ``DCRNN``, ``PDG2SEQ``, ``STMLP``, ``EXP``, ``Informer`` -> their
      dedicated trainer classes;
    * anything else -> the generic ``Trainer``.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler: forwarded positionally, in exactly this
            order, to the selected trainer's constructor.
        args: configuration mapping; must support ``args["basic"]["model"]``.
        kwargs: indexable container. Despite its name this is an ordinary
            positional parameter, not ``**kwargs``. Only ``kwargs[0]`` is
            read, and only for the CDE models.

    Returns:
        A newly constructed trainer instance.
    """
    model_name = args["basic"]["model"]
    base_args = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
    )

    if model_name in _CDE_MODELS:
        # NOTE(review): the meaning of kwargs[0] and of the trailing None is
        # defined by cdeTrainer's constructor signature -- confirm there.
        return cdeTrainer(*base_args, kwargs[0], None)

    trainer_cls = _TRAINER_MAP.get(model_name, Trainer)
    return trainer_cls(*base_args)