from federatedscope.core.configs.config import CN
from federatedscope.register import register_config


def extend_model_cfg(cfg):
    """Attach the default model, criterion and regularizer options to ``cfg``.

    Creates the ``cfg.model``, ``cfg.criterion`` and ``cfg.regularizer``
    sub-nodes, fills them with default values, and registers
    ``assert_model_cfg`` as the check function for these options.

    Args:
        cfg: The configuration node to extend. It is modified in place.
    """
    # ---------------------------------------------------------------------- #
    # Model related options
    # ---------------------------------------------------------------------- #
    cfg.model = CN()

    # Some methods may leverage more than one model in each trainer
    cfg.model.model_num_per_trainer = 1
    cfg.model.type = 'lr'
    cfg.model.use_bias = True
    cfg.model.task = 'node'
    cfg.model.hidden = 256
    # NOTE: this option used to be assigned twice in this function (0.5 here
    # and 0.1 in the traffic-flow section below); the later assignment always
    # won, so 0.1 is the effective default and is kept as the single value.
    cfg.model.dropout = 0.1
    cfg.model.in_channels = 0  # If 0, model will be built by data.shape
    cfg.model.out_channels = 1
    cfg.model.layer = 2  # In GPR-GNN, K = layer
    cfg.model.graph_pooling = 'mean'
    cfg.model.embed_size = 8
    cfg.model.num_item = 0
    cfg.model.num_user = 0
    cfg.model.input_shape = ()  # A tuple, e.g., (in_channel, h, w)

    # For tree-based model
    cfg.model.lambda_ = 0.1
    cfg.model.gamma = 0
    cfg.model.num_of_trees = 10
    cfg.model.max_tree_depth = 3

    # Language model for hetero NLP tasks
    cfg.model.stage = ''  # ['assign', 'contrast']
    cfg.model.model_type = 'google/bert_uncased_L-2_H-128_A-2'
    cfg.model.pretrain_tasks = []
    cfg.model.downstream_tasks = []
    cfg.model.num_labels = 1
    cfg.model.max_length = 200
    cfg.model.min_length = 1
    cfg.model.no_repeat_ngram_size = 3
    cfg.model.length_penalty = 2.0
    cfg.model.num_beams = 5
    cfg.model.label_smoothing = 0.1
    cfg.model.n_best_size = 20
    cfg.model.max_answer_len = 30
    cfg.model.null_score_diff_threshold = 0.0
    cfg.model.use_contrastive_loss = False
    cfg.model.contrast_topk = 100
    cfg.model.contrast_temp = 1.0

    # Traffic flow model parameters. These are only default values;
    # please modify the specific parameters directly in the baselines/YAML
    # files. (`cfg.model.dropout` is set once, in the general options above.)
    cfg.model.num_nodes = 0
    cfg.model.rnn_units = 64
    cfg.model.horizon = 12
    cfg.model.input_dim = 1
    cfg.model.output_dim = 1
    cfg.model.embed_dim = 10
    cfg.model.num_layers = 1
    cfg.model.cheb_order = 1
    cfg.model.use_day = True
    cfg.model.use_week = True

    # ---------------------------------------------------------------------- #
    # Criterion related options
    # ---------------------------------------------------------------------- #
    cfg.criterion = CN()

    cfg.criterion.type = 'MSELoss'

    # ---------------------------------------------------------------------- #
    # Regularizer related options
    # ---------------------------------------------------------------------- #
    cfg.regularizer = CN()

    cfg.regularizer.type = ''
    cfg.regularizer.mu = 0.

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_model_cfg)


def assert_model_cfg(cfg):
    """Check function registered for the model options.

    Currently a placeholder: no validation is performed on ``cfg``.
    """
    pass


register_config("model", extend_model_cfg)