TrafficWheel/trainer/trainer_selector.py

111 lines
2.7 KiB
Python
Executable File

from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
from trainer.E32Trainer import Trainer as EXP_Trainer
def select_trainer(
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    kwargs,
):
    """Construct the trainer that corresponds to the configured model name.

    The model name is read from ``args["basic"]["model"]`` and compared
    against the known names below; any name that is not listed falls back
    to the generic ``Trainer``.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader,
        scaler, args, lr_scheduler: Forwarded positionally, in this order,
            to whichever trainer is chosen.
        kwargs: Indexable container. Only ``kwargs[0]`` is read, and only
            when the model name is ``"STGNCDE"`` or ``"STGNRDE"``.

    Returns:
        The object produced by calling the selected trainer with the
        forwarded arguments.
    """
    model_name = args["basic"]["model"]

    # Every trainer takes these nine positional arguments first.
    shared = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
    )

    # Both CDE-style models use the same trainer, which takes two extra
    # positional arguments: the first element of ``kwargs`` and ``None``.
    if model_name == "STGNCDE" or model_name == "STGNRDE":
        return cdeTrainer(*shared, kwargs[0], None)

    # Models with a dedicated trainer that takes only the shared arguments.
    dedicated = (
        ("DCRNN", DCRNN_Trainer),
        ("PDG2SEQ", PDG2SEQ_Trainer),
        ("STMLP", STMLP_Trainer),
        ("EXP", EXP_Trainer),
    )
    for name, trainer_cls in dedicated:
        if model_name == name:
            return trainer_cls(*shared)

    # Anything else is handled by the generic trainer.
    return Trainer(*shared)