# Project-I/trainer/trainer_selector.py

from trainer.trainer import Trainer
# Registry of model-specific trainer classes, keyed by the value found in
# config['basic']['model']. Models without an entry use the generic Trainer.
# Register a specialised trainer with e.g. _TRAINER_REGISTRY['my_model'] = MyTrainer.
_TRAINER_REGISTRY: dict[str, type] = {}


def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                   lr_scheduler, kwargs):
    """Construct the trainer for the model named in the config.

    The model name is read from ``config['basic']['model']`` and looked up in
    ``_TRAINER_REGISTRY``. When no specialised trainer is registered for that
    name, the generic ``Trainer`` is used. The chosen class is called
    positionally with ``(config, model, loss, optimizer, train_loader,
    val_loader, test_loader, scaler, lr_scheduler)``.

    Args:
        config: Configuration mapping; must contain ``config['basic']['model']``.
        model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
        lr_scheduler: Forwarded unchanged to the trainer constructor.
        kwargs: Accepted for call-site compatibility; currently unused and
            not forwarded to the trainer.

    Returns:
        The constructed trainer instance.

    Raises:
        KeyError: If ``config`` (a dict) lacks the ``'basic'`` section or its
            ``'model'`` entry.
    """
    model_name = config['basic']['model']
    # Fall back to the generic Trainer when this model has no specialised
    # trainer registered.
    trainer_cls = _TRAINER_REGISTRY.get(model_name)
    if trainer_cls is None:
        trainer_cls = Trainer
    return trainer_cls(config, model, loss, optimizer,
                       train_loader, val_loader, test_loader, scaler, lr_scheduler)