111 lines
2.7 KiB
Python
Executable File
111 lines
2.7 KiB
Python
Executable File
from trainer.Trainer import Trainer
|
|
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
|
|
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
|
|
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
|
|
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
|
|
from trainer.E32Trainer import Trainer as EXP_Trainer
|
|
|
|
|
|
def select_trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
kwargs,
|
|
):
|
|
model_name = args["basic"]["model"]
|
|
match model_name:
|
|
case "STGNCDE":
|
|
return cdeTrainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
kwargs[0],
|
|
None,
|
|
)
|
|
case "STGNRDE":
|
|
return cdeTrainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
kwargs[0],
|
|
None,
|
|
)
|
|
case "DCRNN":
|
|
return DCRNN_Trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
)
|
|
case "PDG2SEQ":
|
|
return PDG2SEQ_Trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
)
|
|
case "STMLP":
|
|
return STMLP_Trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
)
|
|
case "EXP":
|
|
return EXP_Trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
)
|
|
case _:
|
|
return Trainer(
|
|
model,
|
|
loss,
|
|
optimizer,
|
|
train_loader,
|
|
val_loader,
|
|
test_loader,
|
|
scaler,
|
|
args,
|
|
lr_scheduler,
|
|
)
|