"""Default configuration options for traffic flow prediction (TFP).

The effective parameter settings for traffic flow prediction live in the
YAML files under the ``baseline`` folder of the ``trafficflow`` package.
The values registered here are only defaults; change the specific
parameters directly in those YAML files.
"""
from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_trafficflow_cfg(cfg):
    """Extend ``cfg`` in place with the traffic-flow-prediction defaults.

    Creates the ``cfg.model.tfp`` sub-node. It then assigns fresh
    ``cfg.criterion`` and ``cfg.regularizer`` nodes that hold TFP
    defaults. Finally, it registers :func:`assert_model_cfg` as a config
    check function.

    Args:
        cfg: The root config node to extend. It must already have a
            ``model`` node and a ``register_cfg_check_fun`` method.
    """
    # ---------------------------------------------------------------------- #
    # Model related options
    # ---------------------------------------------------------------------- #
    cfg.model.tfp = CN()

    # Some methods may leverage more than one model in each trainer.
    cfg.model.tfp.model_num_per_trainer = 1

    # NOTE(review): 0 looks like a placeholder expected to be overridden by
    # the dataset / YAML config — confirm against the data loader.
    cfg.model.tfp.num_nodes = 0
    cfg.model.tfp.rnn_units = 64
    cfg.model.tfp.dropout = 0.1
    cfg.model.tfp.horizon = 12
    # NOTE(review): an earlier comment claimed "if 0, the model is built
    # from data.shape"; that is not verifiable from this file — confirm.
    cfg.model.tfp.input_dim = 1
    cfg.model.tfp.output_dim = 1
    cfg.model.tfp.embed_dim = 10
    cfg.model.tfp.num_layers = 1
    cfg.model.tfp.cheb_order = 1
    cfg.model.tfp.use_day = True
    cfg.model.tfp.use_week = True

    # Mini-graph options (disabled by default).
    cfg.model.tfp.minigraph = CN()
    cfg.model.tfp.minigraph.enable = False
    cfg.model.tfp.minigraph.size = 5

    # ---------------------------------------------------------------------- #
    # Criterion related options
    # ---------------------------------------------------------------------- #
    # NOTE(review): this assigns a brand-new node, replacing any
    # ``cfg.criterion`` created by another config extension registered
    # earlier — confirm that the ordering is intended.
    cfg.criterion = CN()
    cfg.criterion.type = 'L1Loss'

    # ---------------------------------------------------------------------- #
    # Regularizer related options
    # ---------------------------------------------------------------------- #
    # NOTE(review): also replaces any pre-existing ``cfg.regularizer`` node.
    cfg.regularizer = CN()
    cfg.regularizer.type = ''
    cfg.regularizer.mu = 0.

    # --------------- register corresponding check function ----------------
    cfg.register_cfg_check_fun(assert_model_cfg)


def assert_model_cfg(cfg):
    """Validate the traffic-flow config; currently a no-op placeholder.

    Args:
        cfg: The root config node. It is accepted for the check-function
            interface but is not inspected yet.
    """
    pass


register_config("trafficflow", extend_trafficflow_cfg)