FS-TFP/federatedscope/core/configs/cfg_trafficflow.py

58 lines
1.9 KiB
Python

from federatedscope.core.configs.config import CN
from federatedscope.register import register_config
"""
The parameter settings for traffic flow prediction are defined in the YAML
files under the ``baseline`` folder of the trafficflow package. The values
below are only defaults; change specific parameters directly in those YAML
files.
"""
def extend_trafficflow_cfg(cfg):
    """Install the default traffic-flow prediction options on ``cfg``.

    The values set here are fallbacks only; the effective settings are
    meant to be supplied by the YAML files under the trafficflow package's
    ``baseline`` folder.

    Note that ``cfg.model``, ``cfg.criterion`` and ``cfg.regularizer`` are
    each replaced by a fresh ``CN()`` node, so any keys that were already
    present on those nodes when this function runs are discarded.

    Args:
        cfg: the root configuration node; modified in place.

    Returns:
        None.
    """
    # ---------------------------------------------------------------------- #
    # Model related options
    # ---------------------------------------------------------------------- #
    cfg.model = CN()
    # Some methods may leverage more than one model in each trainer.
    cfg.model.model_num_per_trainer = 1
    cfg.model.type = 'trafficflow'
    cfg.model.use_bias = True
    cfg.model.task = 'trafficflowprediction'
    # Placeholder value; presumably set per dataset in the YAML — confirm.
    cfg.model.num_nodes = 0
    cfg.model.rnn_units = 64
    cfg.model.dropout = 0.1
    # Presumably the number of future time steps to forecast — confirm.
    cfg.model.horizon = 12
    cfg.model.input_dim = 1
    cfg.model.output_dim = 1
    cfg.model.embed_dim = 10
    cfg.model.num_layers = 1
    # A plain integer order (presumably of a Chebyshev graph convolution —
    # confirm against the model), not a shape tuple.
    cfg.model.cheb_order = 1
    # Presumably toggle time-of-day / day-of-week features — confirm.
    cfg.model.use_day = True
    cfg.model.use_week = True

    # ---------------------------------------------------------------------- #
    # Criterion related options
    # ---------------------------------------------------------------------- #
    cfg.criterion = CN()
    cfg.criterion.type = 'L1Loss'

    # ---------------------------------------------------------------------- #
    # regularizer related options
    # ---------------------------------------------------------------------- #
    cfg.regularizer = CN()
    cfg.regularizer.type = ''
    cfg.regularizer.mu = 0.

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_model_cfg)
def assert_model_cfg(cfg):
    """Validate the traffic-flow configuration; currently a no-op.

    This is the check function handed to ``cfg.register_cfg_check_fun`` by
    ``extend_trafficflow_cfg``. It performs no validation and accepts any
    ``cfg``.

    Args:
        cfg: the configuration node to validate (unused).

    Returns:
        None.
    """
    del cfg  # intentionally unused: no traffic-flow specific checks yet
    return None
# Module-level side effect (runs at import time): register
# extend_trafficflow_cfg with federatedscope.register under the name
# "trafficflow".
register_config("trafficflow", extend_trafficflow_cfg)