# File: TrafficWheel/trainer/trainer_selector.py

from trainer.Trainer import Trainer
from trainer.cdeTrainer.cdetrainer import Trainer as cdeTrainer
from trainer.DCRNN_Trainer import Trainer as DCRNN_Trainer
from trainer.PDG2SEQ_Trainer import Trainer as PDG2SEQ_Trainer
from trainer.EXP_trainer import Trainer as EXP_Trainer
def select_trainer(model, loss, optimizer, train_loader, val_loader, test_loader, scaler, args,
                   lr_scheduler, kwargs):
    """Instantiate the trainer class that matches ``args['model']['type']``.

    Dispatch rules (checked in this order, by ``==`` on the model type):

    * ``"STGNCDE"`` -> ``cdeTrainer``; besides the shared arguments it also
      receives ``kwargs[0]`` and ``None`` as two trailing positional arguments.
    * ``"DCRNN"``   -> ``DCRNN_Trainer``
    * ``"PDG2SEQ"`` -> ``PDG2SEQ_Trainer``
    * ``"EXP"``     -> ``EXP_Trainer``
    * anything else -> the generic ``Trainer``

    Every trainer receives the same leading positional arguments, in this
    order: ``model, loss, optimizer, train_loader, val_loader, test_loader,
    scaler, args['train'], lr_scheduler``. Only the ``'train'`` section of
    ``args`` is forwarded, not the whole config.

    Args:
        model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
        lr_scheduler: Forwarded unchanged to the selected trainer.
        args: Config indexed as ``args['model']['type']`` (selects the trainer)
            and ``args['train']`` (forwarded to it).
        kwargs: Despite its name, this is indexed positionally. Only
            ``kwargs[0]`` is read, and only for the STGNCDE trainer.
            NOTE(review): its meaning is defined by ``cdeTrainer``; confirm
            there.

    Returns:
        The constructed trainer instance.

    Raises:
        Whatever ``args`` raises for a missing ``'model'``, ``'type'`` or
        ``'train'`` entry, and whatever the chosen trainer's constructor raises.
    """
    model_type = args['model']['type']
    shared_args = (model, loss, optimizer, train_loader, val_loader, test_loader,
                   scaler, args['train'], lr_scheduler)

    if model_type == "STGNCDE":
        # Only this trainer takes the extra trailing arguments; kwargs is not
        # touched for any other model type.
        return cdeTrainer(*shared_args, kwargs[0], None)
    if model_type == 'DCRNN':
        return DCRNN_Trainer(*shared_args)
    if model_type == 'PDG2SEQ':
        return PDG2SEQ_Trainer(*shared_args)
    if model_type == 'EXP':
        return EXP_Trainer(*shared_args)
    # Fallback for every other model type.
    return Trainer(*shared_args)