21 lines
1.3 KiB
Python
Executable File
21 lines
1.3 KiB
Python
Executable File
from trainer.Trainer import Trainer
|
|
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
|
|
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
|
|
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
|
|
from trainer.EXP_trainer import Trainer as EXP_Trainer
|
|
|
|
|
|
def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                   lr_scheduler, kwargs):
    """Build the trainer that matches the configured model type.

    The model type is read from ``args['model']['type']``. Every trainer gets
    the same leading positional arguments: model, loss, optimizer, the three
    data loaders, scaler, the ``args['train']`` sub-config, and lr_scheduler.

    * ``"STGNCDE"`` -> ``cdeTrainer``, which additionally receives
      ``kwargs[0]`` and ``None`` as trailing positional arguments.
    * ``"DCRNN"``   -> ``DCRNN_Trainer``
    * ``"PDG2SEQ"`` -> ``PDG2SEQ_Trainer``
    * ``"EXP"``     -> ``EXP_Trainer``
    * anything else -> the generic ``Trainer``

    Args:
        kwargs: despite its name, a positional, indexable object. Only its
            first element is used, and only for STGNCDE. NOTE(review):
            presumably extra CDE-specific data; confirm against the caller.

    Returns:
        The constructed trainer instance.
    """
    model_type = args['model']['type']

    # Positional arguments common to every trainer constructor, in call order.
    common = (model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
              args['train'], lr_scheduler)

    # STGNCDE is the only trainer with a longer signature; handle it first.
    if model_type == "STGNCDE":
        return cdeTrainer(*common, kwargs[0], None)

    if model_type == 'DCRNN':
        trainer_cls = DCRNN_Trainer
    elif model_type == 'PDG2SEQ':
        trainer_cls = PDG2SEQ_Trainer
    elif model_type == 'EXP':
        trainer_cls = EXP_Trainer
    else:
        # Unrecognised model types fall back to the generic trainer.
        trainer_cls = Trainer

    return trainer_cls(*common)
|