22 lines
1.1 KiB
Python
22 lines
1.1 KiB
Python
from federatedscope.vertical_fl.tree_based_models.model \
|
|
import MultipleXGBTrees, MultipleGBDTTrees, RandomForest
|
|
|
|
|
|
def get_tree_model(model_config, criterion_type=None):
    """Build the tree ensemble selected by ``model_config.type``.

    The type string is compared case-insensitively against
    ``'xgb_tree'``, ``'gbdt_tree'`` and ``'random_forest'``; these map to
    ``MultipleXGBTrees``, ``MultipleGBDTTrees`` and ``RandomForest``
    respectively.

    Args:
        model_config: Configuration object exposing ``type`` and, for a
            recognised type, ``max_tree_depth``, ``lambda_``, ``gamma``
            and ``num_of_trees``.
        criterion_type: Accepted for interface compatibility; it is not
            read anywhere in this function.

    Returns:
        An instance of the selected tree model, or ``None`` when
        ``model_config.type`` matches none of the supported names.
    """
    requested = model_config.type.lower()

    # Resolve the class first; every supported model is constructed with
    # the same four hyper-parameters, so one call site below is enough.
    if requested == 'xgb_tree':
        tree_cls = MultipleXGBTrees
    elif requested == 'gbdt_tree':
        tree_cls = MultipleGBDTTrees
    elif requested == 'random_forest':
        tree_cls = RandomForest
    else:
        # Unrecognised type: hand back None rather than raising, so the
        # caller decides how to treat an unsupported model name.
        return None

    return tree_cls(
        max_depth=model_config.max_tree_depth,
        lambda_=model_config.lambda_,
        gamma=model_config.gamma,
        num_of_trees=model_config.num_of_trees,
    )
|