# FS-TFP/federatedscope/trafficflow/trainer/trafficflow_trainer.py
#
# File-viewer metadata captured with this listing: 16 lines, 672 B, Python.

from federatedscope.register import register_trainer
def call_trafficflow_trainer(config, model, data, device, monitor):
    """Build a ``TrafficflowTrainer`` when the config selects it.

    Args:
        config: Configuration object. ``config.trainer.type`` selects the
            trainer, ``config.data.scaler`` is forwarded as ``scaler``, and
            the whole object is forwarded as ``args``.
        model: Forwarded to the trainer as ``model``.
        data: Forwarded to the trainer as ``data``.
        device: Forwarded to the trainer as ``device``.
        monitor: Forwarded to the trainer as ``monitor``.

    Returns:
        A ``TrafficflowTrainer`` instance when ``config.trainer.type`` equals
        ``'trafficflowtrainer'``; otherwise ``None``.
    """
    if config.trainer.type != 'trafficflowtrainer':
        # Some other trainer was requested: report "not handled" explicitly
        # instead of falling through to an unbound local.
        return None

    # Imported inside the function so the trainer module is only loaded when
    # this trainer type is actually requested.
    from federatedscope.trafficflow.trainer.trafficflow import (
        TrafficflowTrainer,
    )

    return TrafficflowTrainer(
        model=model,
        scaler=config.data.scaler,
        args=config,
        data=data,
        device=device,
        monitor=monitor,
    )
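# NOTE(review): registration is currently disabled, so importing this module
# does not register the builder above. Confirm whether it is registered
# elsewhere before re-enabling the line below.
# register_trainer('trafficflowtrainer', call_trafficflow_trainer)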