38 lines
1.2 KiB
Python
Executable File
38 lines
1.2 KiB
Python
Executable File
from trainer.Trainer import Trainer
|
|
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
|
|
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
|
|
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
|
|
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
|
|
from trainer.E32Trainer import Trainer as EXP_Trainer
|
|
from trainer.InformerTrainer import InformerTrainer
|
|
from trainer.TSTrainer import Trainer as TSTrainer
|
|
|
|
|
|
def select_trainer(
    model, loss, optimizer,
    train_loader, val_loader, test_loader,
    scaler, args, lr_scheduler, kwargs
):
    """Build the trainer object that corresponds to the configured model.

    The model name is read from ``args["basic"]["model"]`` and used to pick
    a trainer class; that class is then instantiated and returned.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler: Forwarded positionally, in this order,
            to the chosen trainer's constructor.
        kwargs: Indexable container of extra arguments. Only read for the
            "STGNCDE" / "STGNRDE" models, where ``kwargs[0]`` is passed to
            the CDE trainer followed by a literal ``None``.
            NOTE(review): the meaning of ``kwargs[0]`` and of the trailing
            ``None`` is defined by ``cdeTrainer`` — confirm against it.

    Returns:
        An instance of the selected trainer class. Model names with no
        dedicated entry fall back to the generic ``Trainer``.
    """
    name = args["basic"]["model"]
    shared = [
        model, loss, optimizer,
        train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler,
    ]

    # The CDE trainer is the only one whose constructor takes extra
    # positional arguments, so it is handled before the generic dispatch.
    if name in {"STGNCDE", "STGNRDE"}:
        return cdeTrainer(*shared, kwargs[0], None)

    # Model name -> trainer class for every model that uses the common
    # nine-argument constructor call.
    dispatch = {
        "HI": TSTrainer,
        "PatchTST": TSTrainer,
        "iTransformer": TSTrainer,
        "DCRNN": DCRNN_Trainer,
        "PDG2SEQ": PDG2SEQ_Trainer,
        "STMLP": STMLP_Trainer,
        "EXP": EXP_Trainer,
        "Informer": InformerTrainer,
    }

    # Anything not listed above is trained with the default Trainer.
    trainer_cls = dispatch.get(name, Trainer)
    return trainer_cls(*shared)
|