TrafficWheel/trainer/trainer_selector.py

38 lines
1.2 KiB
Python
Executable File

from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer
from trainer.InformerTrainer import InformerTrainer
from trainer.TSTrainer import Trainer as TSTrainer
def select_trainer(
    model, loss, optimizer,
    train_loader, val_loader, test_loader,
    scaler, args, lr_scheduler, kwargs=None,
):
    """Construct the trainer that matches the configured model name.

    The model name is read from ``args["basic"]["model"]``. Known names are
    routed to their dedicated trainer class; any other name falls back to the
    generic ``Trainer``.

    Args:
        model, loss, optimizer: Forwarded positionally to the trainer.
        train_loader, val_loader, test_loader: Forwarded positionally to the
            trainer.
        scaler: Forwarded positionally to the trainer.
        args: Nested config mapping; must contain ``args["basic"]["model"]``.
            It is also forwarded to the trainer.
        lr_scheduler: Forwarded positionally to the trainer.
        kwargs: Indexable container of extra positional values. Only the
            STGNCDE/STGNRDE branch reads it, and only ``kwargs[0]``. Defaults
            to ``None`` so callers selecting other models may omit it.

    Returns:
        A newly constructed trainer instance.

    Raises:
        KeyError: If ``args`` has no ``["basic"]["model"]`` entry.
        ValueError: If STGNCDE/STGNRDE is selected but ``kwargs`` is ``None``
            or empty.
    """
    model_name = args["basic"]["model"]

    # Every trainer receives the same leading positional arguments.
    base_args = (
        model, loss, optimizer,
        train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler,
    )

    if model_name in {"STGNCDE", "STGNRDE"}:
        # The CDE trainer takes two extra positional arguments: kwargs[0] and
        # an explicit None (kept exactly as the original call site passed it).
        if kwargs is None or len(kwargs) == 0:
            raise ValueError(
                f"Model {model_name!r} requires extra trainer arguments: "
                "pass a non-empty `kwargs` whose first element is forwarded "
                "to the CDE trainer."
            )
        return cdeTrainer(*base_args, kwargs[0], None)

    # Single dispatch table for every trainer that only needs `base_args`.
    trainer_map = {
        "HI": TSTrainer,
        "PatchTST": TSTrainer,
        "iTransformer": TSTrainer,
        "DCRNN": DCRNN_Trainer,
        "PDG2SEQ": PDG2SEQ_Trainer,
        "STMLP": STMLP_Trainer,
        "EXP": EXP_Trainer,
        "Informer": InformerTrainer,
    }
    # Unlisted model names use the generic Trainer.
    return trainer_map.get(model_name, Trainer)(*base_args)