from federatedscope.vertical_fl.tree_based_models.model \
    import MultipleXGBTrees, MultipleGBDTTrees, RandomForest

# Maps the lower-cased ``model_config.type`` to the tree-ensemble class it
# selects. Every class here is constructed with the same keyword arguments
# (see ``get_tree_model``), so a single construction site serves all of them.
_TREE_MODEL_CLASSES = {
    'xgb_tree': MultipleXGBTrees,
    'gbdt_tree': MultipleGBDTTrees,
    'random_forest': RandomForest,
}


def get_tree_model(model_config, criterion_type=None):
    """Build the tree-based model selected by ``model_config.type``.

    Args:
        model_config: config object providing ``type`` (one of
            ``'xgb_tree'``, ``'gbdt_tree'`` or ``'random_forest'``, matched
            case-insensitively), plus ``max_tree_depth``, ``lambda_``,
            ``gamma`` and ``num_of_trees``, which are forwarded to the
            model constructor.
        criterion_type: currently unused; kept so existing callers that
            pass it continue to work.

    Returns:
        A ``MultipleXGBTrees``, ``MultipleGBDTTrees`` or ``RandomForest``
        instance.

    Raises:
        ValueError: if ``model_config.type`` is not one of the supported
            tree model types.
    """
    model_type = model_config.type.lower()
    try:
        model_cls = _TREE_MODEL_CLASSES[model_type]
    except KeyError:
        # Fail loudly instead of returning None, which would otherwise only
        # surface later as an AttributeError far from the misconfiguration.
        raise ValueError(
            f"Unsupported tree model type '{model_config.type}'; expected "
            f"one of {sorted(_TREE_MODEL_CLASSES)}.") from None

    return model_cls(max_depth=model_config.max_tree_depth,
                     lambda_=model_config.lambda_,
                     gamma=model_config.gamma,
                     num_of_trees=model_config.num_of_trees)