16 lines
672 B
Python
16 lines
672 B
Python
from federatedscope.register import register_trainer
|
|
|
|
def call_trafficflow_trainer(config, model, data, device, monitor):
    """Build a ``TrafficflowTrainer`` when the config selects it.

    Args:
        config: Configuration object. ``config.trainer.type`` selects the
            trainer, ``config.data.scaler`` is forwarded as ``scaler``, and
            the whole object is forwarded as ``args``.
        model: Forwarded unchanged to ``TrafficflowTrainer`` as ``model``.
        data: Forwarded unchanged to ``TrafficflowTrainer`` as ``data``.
        device: Forwarded unchanged to ``TrafficflowTrainer`` as ``device``.
        monitor: Forwarded unchanged to ``TrafficflowTrainer`` as
            ``monitor``.

    Returns:
        A ``TrafficflowTrainer`` instance when ``config.trainer.type`` equals
        ``'trafficflowtrainer'``; otherwise ``None``.
    """
    if config.trainer.type != 'trafficflowtrainer':
        # Not our trainer type: signal "no match" explicitly instead of
        # falling off the end of the function.
        return None

    # Imported lazily so the trafficflow package is only loaded when this
    # trainer type is actually requested.
    from federatedscope.trafficflow.trainer.trafficflow import (
        TrafficflowTrainer,
    )

    trainer = TrafficflowTrainer(
        model=model,
        scaler=config.data.scaler,
        args=config,
        data=data,
        device=device,
        monitor=monitor,
    )
    return trainer
|
|
|
|
|
|
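# NOTE(review): the registration call below is disabled, so importing this
# module does not register call_trafficflow_trainer — confirm how callers
# reach the builder before re-enabling or removing this line.
# register_trainer('trafficflowtrainer', call_trafficflow_trainer)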
|