FS-TFP/federatedscope/vertical_fl/tree_based_models/model/model_builder.py

22 lines
1.1 KiB
Python

from federatedscope.vertical_fl.tree_based_models.model \
import MultipleXGBTrees, MultipleGBDTTrees, RandomForest
def get_tree_model(model_config, criterion_type=None):
    """Build a tree-based vertical-FL model described by ``model_config``.

    The model class is selected by ``model_config.type`` (compared
    case-insensitively). Every supported class is constructed with the same
    hyper-parameters read from ``model_config``.

    Args:
        model_config: configuration object exposing the attributes ``type``,
            ``max_tree_depth``, ``lambda_``, ``gamma`` and ``num_of_trees``.
        criterion_type: not used by this function; accepted so that callers
            passing it keep working.

    Returns:
        A ``MultipleXGBTrees`` instance for ``'xgb_tree'``, a
        ``MultipleGBDTTrees`` instance for ``'gbdt_tree'``, or a
        ``RandomForest`` instance for ``'random_forest'``.

    Raises:
        ValueError: if ``model_config.type`` names none of the supported
            tree models.
    """
    model_type = model_config.type.lower()

    # Resolve the class first so an unsupported type fails fast with a clear
    # message, before any other config attribute is read.
    if model_type == 'xgb_tree':
        tree_cls = MultipleXGBTrees
    elif model_type == 'gbdt_tree':
        tree_cls = MultipleGBDTTrees
    elif model_type == 'random_forest':
        tree_cls = RandomForest
    else:
        raise ValueError(
            "Unsupported tree model type '{}'; expected one of "
            "'xgb_tree', 'gbdt_tree', 'random_forest'.".format(
                model_config.type))

    # All three classes are built from the same constructor keywords.
    return tree_cls(max_depth=model_config.max_tree_depth,
                    lambda_=model_config.lambda_,
                    gamma=model_config.gamma,
                    num_of_trees=model_config.num_of_trees)