FS-TFP/federatedscope/vertical_fl/tree_based_models/trainer/utils.py

66 lines
2.7 KiB
Python

from federatedscope.vertical_fl.tree_based_models.trainer \
import VerticalTrainer, RandomForestTrainer
import numpy as np
def get_vertical_trainer(config, model, data, device, monitor):
    """Build the tree-based vertical-FL trainer selected by ``config``.

    The base trainer class is ``RandomForestTrainer`` when
    ``config.model.type`` (compared case-insensitively) equals
    ``'random_forest'`` and ``VerticalTrainer`` otherwise.
    ``config.vertical.protect_object`` then decides how it is built:

    * falsy (e.g. ``''`` or ``None``): the base class is instantiated
      directly;
    * ``'feature_order'``: built via ``createFeatureOrderProtectedTrainer``;
    * ``'grad_and_hess'``: built via ``createLabelProtectedTrainer``.

    Args:
        config: Configuration object exposing ``model.type`` and
            ``vertical.protect_object``.
        model: Forwarded unchanged to the trainer.
        data: Forwarded unchanged to the trainer.
        device: Forwarded unchanged to the trainer.
        monitor: Forwarded unchanged to the trainer.

    Returns:
        Whatever the selected trainer class or factory returns.

    Raises:
        ValueError: If ``config.vertical.protect_object`` is set to a value
            other than the supported ones listed above.
    """
    protect_object = config.vertical.protect_object
    supported = ('feature_order', 'grad_and_hess')

    # Fail early, and say what was wrong, instead of a message-less
    # ValueError after the trainer class has already been resolved.
    if protect_object and protect_object not in supported:
        raise ValueError(
            f"Unsupported config.vertical.protect_object: "
            f"{protect_object!r}; expected '' (no protection) or one of "
            f"{supported}")

    if config.model.type.lower() == 'random_forest':
        trainer_cls = RandomForestTrainer
    else:
        trainer_cls = VerticalTrainer

    trainer_kwargs = dict(model=model,
                          data=data,
                          device=device,
                          config=config,
                          monitor=monitor)

    # An empty string is falsy, so this single check covers both '' and None.
    if not protect_object:
        return trainer_cls(**trainer_kwargs)

    # The wrapper factories are imported lazily, only for the branch that
    # needs them (as in the original implementation).
    if protect_object == 'feature_order':
        from federatedscope.vertical_fl.tree_based_models.trainer import \
            createFeatureOrderProtectedTrainer
        return createFeatureOrderProtectedTrainer(cls=trainer_cls,
                                                  **trainer_kwargs)

    # Only 'grad_and_hess' can reach this point after the validation above.
    from federatedscope.vertical_fl.tree_based_models.trainer import \
        createLabelProtectedTrainer
    return createLabelProtectedTrainer(cls=trainer_cls, **trainer_kwargs)
def bucketize(feature_order, bucket_size, bucket_num):
    """Split ``feature_order`` into consecutive buckets.

    Args:
        feature_order: Sliceable sequence to partition (only ``len()`` and
            slicing are used on it).
        bucket_size: Either a single integer (Python ``int`` or NumPy
            integer scalar) giving the nominal size of every bucket, or an
            iterable holding one explicit size per bucket.
        bucket_num: Number of buckets; only used when ``bucket_size`` is a
            single integer.

    Returns:
        list: Slices of ``feature_order`` in their original order, one per
        bucket.

    Notes:
        When ``bucket_size`` is a single integer and
        ``bucket_size * bucket_num`` differs from ``len(feature_order)``,
        the difference is absorbed by growing (or shrinking) randomly
        chosen buckets by one element each. The buckets are drawn with
        ``np.random.choice(..., replace=False)``, which cannot draw more
        than ``bucket_num`` distinct indices, so the absolute difference
        must not exceed ``bucket_num``.
    """
    # NumPy integer scalars are not instances of ``int``; accept them too so
    # they are not mistaken for an iterable of per-bucket sizes.
    if isinstance(bucket_size, (int, np.integer)):
        nominal = int(bucket_size)
        remainder = len(feature_order) - nominal * bucket_num
        sizes = [nominal] * bucket_num
        if remainder != 0:
            # Each selected bucket changes by exactly one element; distinct
            # buckets are picked so no bucket is adjusted twice.
            step = 1 if remainder > 0 else -1
            selected_idx = np.random.choice(a=bucket_num,
                                            size=abs(remainder),
                                            replace=False)
            for idx in selected_idx:
                sizes[idx] += step
    else:
        # Caller supplied explicit per-bucket sizes; use them as given.
        sizes = bucket_size

    bucketized_feature_order = []
    start = 0
    for each_bucket_size in sizes:
        end = start + each_bucket_size
        bucketized_feature_order.append(feature_order[start:end])
        start = end
    return bucketized_feature_order