139 lines
3.5 KiB
Python
Executable File
139 lines
3.5 KiB
Python
Executable File
from trainer.Trainer import Trainer
|
|
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
|
|
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
|
|
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
|
|
from trainer.STMLP_Trainer import Trainer as STMLP_Trainer
|
|
from trainer.E32Trainer import Trainer as EXP_Trainer
|
|
from trainer.InformerTrainer import InformerTrainer
|
|
from trainer.TSTrainer import Trainer as TSTrainer
|
|
|
|
def select_trainer(
    model,
    loss,
    optimizer,
    train_loader,
    val_loader,
    test_loader,
    scaler,
    args,
    lr_scheduler,
    kwargs,
):
    """Construct the trainer that matches ``args["basic"]["model"]``.

    Every trainer is called with the same nine leading positional arguments:
    ``model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
    args, lr_scheduler``. Models routed to ``cdeTrainer`` ("STGNCDE",
    "STGNRDE") additionally receive ``kwargs[0]`` and ``None`` as trailing
    positional arguments.

    Args:
        model: Model instance handed to the trainer.
        loss: Loss object handed to the trainer.
        optimizer: Optimizer handed to the trainer.
        train_loader: Training data loader handed to the trainer.
        val_loader: Validation data loader handed to the trainer.
        test_loader: Test data loader handed to the trainer.
        scaler: Passed through unchanged to the trainer.
        args: Config mapping; ``args["basic"]["model"]`` selects the trainer
            and the whole mapping is also passed to the trainer.
        lr_scheduler: Learning-rate scheduler handed to the trainer.
        kwargs: Indexable container; only ``kwargs[0]`` is read, and only for
            the models routed to ``cdeTrainer``. Note that despite its name
            this is an ordinary positional parameter, not ``**kwargs``.

    Returns:
        A trainer instance. Model names not listed explicitly fall back to
        the generic ``Trainer``.

    Raises:
        KeyError: if ``args`` lacks ``["basic"]["model"]``.
    """
    model_name = args["basic"]["model"]

    # Shared leading positional arguments, built once instead of being
    # repeated verbatim in every branch.
    common = (
        model,
        loss,
        optimizer,
        train_loader,
        val_loader,
        test_loader,
        scaler,
        args,
        lr_scheduler,
    )

    # These models all share TSTrainer; checked first, as in the original.
    if model_name in ("HI", "PatchTST", "iTransformer"):
        return TSTrainer(*common)

    # cdeTrainer is the only trainer that takes extra trailing arguments.
    # kwargs is indexed only on this path, so other models may pass anything.
    if model_name in ("STGNCDE", "STGNRDE"):
        return cdeTrainer(*common, kwargs[0], None)

    # Remaining models map 1:1 onto a trainer class; unknown names use the
    # generic Trainer.
    trainer_cls = {
        "DCRNN": DCRNN_Trainer,
        "PDG2SEQ": PDG2SEQ_Trainer,
        "STMLP": STMLP_Trainer,
        "EXP": EXP_Trainer,
        "Informer": InformerTrainer,
    }.get(model_name, Trainer)
    return trainer_cls(*common)
|