14 lines
754 B
Python
14 lines
754 B
Python
from trainer.trainer import Trainer
|
|
from trainer.ode_trainer import Trainer as ode_trainer
|
|
|
|
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
                   lr_scheduler, kwargs=None):
    """Instantiate the trainer class that matches the configured model.

    The model name is read from ``config['basic']['model']``. The name
    ``'STDEN'`` selects the trainer from ``trainer.ode_trainer``; any other
    name falls back to the default ``trainer.trainer.Trainer``.

    Args:
        config: Nested configuration mapping; must provide ``['basic']['model']``.
            It is also passed on to the trainer constructor.
        model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
        lr_scheduler: Passed positionally, in this order, to the selected
            trainer's constructor.
        kwargs: Not used by this function; kept (now optional) so existing
            call sites that pass it keep working.

    Returns:
        A newly constructed instance of the selected trainer class.

    Raises:
        KeyError: If ``config`` has no ``'basic'`` section or no ``'model'`` key in it.
    """
    model_name = config['basic']['model']

    # Only STDEN is routed to the ODE trainer; every other model name uses the
    # default trainer, so a trainer is always chosen and no "unsupported model"
    # path exists here.
    trainer_cls = ode_trainer if model_name == 'STDEN' else Trainer

    return trainer_cls(config, model, loss, optimizer,
                       train_loader, val_loader, test_loader, scaler, lr_scheduler)