58 lines
1.9 KiB
Python
58 lines
1.9 KiB
Python
from federatedscope.core.configs.config import CN
|
|
from federatedscope.register import register_config
|
|
|
|
"""
|
|
The parameter settings for traffic flow prediction are located in the YAML files under
|
|
the baseline folder within the trafficflow package. These are only default values.
|
|
Please modify the specific parameters directly in the YAML files.
|
|
"""
|
|
|
|
|
|
def extend_trafficflow_cfg(cfg):
    """Add the default traffic-flow options to ``cfg``.

    Assigns fresh ``cfg.model``, ``cfg.criterion`` and ``cfg.regularizer``
    sub-nodes (replacing any existing nodes with those names) and fills them
    with default values. It then registers :func:`assert_model_cfg` as a
    config check function. The defaults are meant to be overridden by the
    YAML files of the trafficflow package.

    Args:
        cfg: Root config node to extend; it is mutated in place.
    """
    # ---------------------------------------------------------------------- #
    # Model related options
    # ---------------------------------------------------------------------- #
    cfg.model = CN()

    # Some methods may leverage more than one model in each trainer.
    cfg.model.model_num_per_trainer = 1
    cfg.model.type = 'trafficflow'
    cfg.model.use_bias = True
    cfg.model.task = 'trafficflowprediction'
    cfg.model.num_nodes = 0
    cfg.model.rnn_units = 64
    cfg.model.dropout = 0.1
    cfg.model.horizon = 12
    # NOTE(review): an earlier comment claimed "if 0, the model is built from
    # data.shape"; confirm the traffic-flow model supports input_dim=0.
    cfg.model.input_dim = 1
    cfg.model.output_dim = 1
    cfg.model.embed_dim = 10
    cfg.model.num_layers = 1
    # Integer default. Presumably the Chebyshev polynomial order; the old
    # comment describing a tuple such as (in_channel, h, w) did not match
    # this value.
    cfg.model.cheb_order = 1
    cfg.model.use_day = True
    cfg.model.use_week = True

    # ---------------------------------------------------------------------- #
    # Criterion related options
    # ---------------------------------------------------------------------- #
    cfg.criterion = CN()

    # Loss name; presumably resolved to a loss class elsewhere.
    cfg.criterion.type = 'L1Loss'

    # ---------------------------------------------------------------------- #
    # Regularizer related options
    # ---------------------------------------------------------------------- #
    cfg.regularizer = CN()

    # Empty type presumably means "no regularizer"; confirm with the consumer.
    cfg.regularizer.type = ''
    cfg.regularizer.mu = 0.

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_model_cfg)
|
|
|
|
|
|
def assert_model_cfg(cfg):
    """Check the traffic-flow config for consistency.

    No validations are implemented yet, so every ``cfg`` is accepted
    without being inspected.

    Args:
        cfg: Config node to check (currently unused).

    Returns:
        None.
    """
    return None
|
|
|
|
|
|
# Register extend_trafficflow_cfg with FederatedScope's config registry under
# the name "trafficflow" (presumably invoked when the global config is built;
# confirm in federatedscope.register).
register_config("trafficflow", extend_trafficflow_cfg)
|