"""Factory that picks the trainer class matching the configured model name."""

from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer
from trainer.InformerTrainer import InformerTrainer
from trainer.TSTrainer import Trainer as TSTrainer

# Models routed to TSTrainer. This check runs before every other mapping,
# so a name listed here takes precedence over the tables below.
_TS_MODELS = frozenset({"HI", "PatchTST", "iTransformer"})

# Models routed to cdeTrainer, which receives two extra positional arguments
# after the common ones: kwargs[0] and None.
# NOTE(review): the meaning of those two extras is defined in
# trainer.cdeTrainer.cdetrainer — confirm there.
_CDE_MODELS = frozenset({"STGNCDE", "STGNRDE"})

# Models whose trainer takes only the common nine positional arguments.
# Any name not found here (nor in the sets above) falls back to Trainer.
_STANDARD_TRAINERS = {
    "DCRNN": DCRNN_Trainer,
    "PDG2SEQ": PDG2SEQ_Trainer,
    "STMLP": STMLP_Trainer,
    "EXP": EXP_Trainer,
    "Informer": InformerTrainer,
}


def select_trainer(
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    kwargs,
):
    """Construct the trainer appropriate for ``args["basic"]["model"]``.

    The first nine arguments are forwarded positionally, in order, to the
    chosen trainer's constructor.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, lr_scheduler: forwarded unchanged to the trainer.
        args: configuration mapping; ``args["basic"]["model"]`` selects the
            trainer, and ``args`` itself is also forwarded to it.
        kwargs: indexable of extra trainer inputs (despite the name, it is
            indexed positionally). Only ``kwargs[0]`` is read, and only for
            the CDE models ``"STGNCDE"`` / ``"STGNRDE"``.

    Returns:
        An instance of the selected trainer class. ``Trainer`` is used when
        the model name matches no specialised trainer.

    Raises:
        KeyError: if ``args`` lacks the ``["basic"]["model"]`` entry.
        IndexError: if a CDE model is selected and ``kwargs`` is empty.
    """
    model_name = args["basic"]["model"]
    common = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
    )

    if model_name in _TS_MODELS:
        return TSTrainer(*common)
    if model_name in _CDE_MODELS:
        return cdeTrainer(*common, kwargs[0], None)
    trainer_cls = _STANDARD_TRAINERS.get(model_name, Trainer)
    return trainer_cls(*common)