66 lines
2.7 KiB
Python
66 lines
2.7 KiB
Python
from federatedscope.vertical_fl.tree_based_models.trainer \
|
|
import VerticalTrainer, RandomForestTrainer
|
|
import numpy as np
|
|
|
|
|
|
def get_vertical_trainer(config, model, data, device, monitor):
    """
    Build the trainer for tree-based vertical FL selected by ``config``.

    The base trainer class is ``RandomForestTrainer`` when
    ``config.model.type`` equals ``'random_forest'`` (case-insensitive) and
    ``VerticalTrainer`` otherwise. ``config.vertical.protect_object`` then
    selects how the trainer is created:

    * empty / falsy: the base class is instantiated directly;
    * ``'feature_order'``: the base class is handed, as ``cls``, to
      ``createFeatureOrderProtectedTrainer``;
    * ``'grad_and_hess'``: the base class is handed, as ``cls``, to
      ``createLabelProtectedTrainer``.

    Args:
        config: Configuration object; ``config.model.type`` and
            ``config.vertical.protect_object`` are read here.
        model: Forwarded unchanged to the trainer class / factory.
        data: Forwarded unchanged to the trainer class / factory.
        device: Forwarded unchanged to the trainer class / factory.
        monitor: Forwarded unchanged to the trainer class / factory.

    Returns:
        The object returned by the selected trainer class or factory.

    Raises:
        ValueError: If ``config.vertical.protect_object`` is non-empty and
            not one of the supported values.
    """
    protect_object = config.vertical.protect_object
    supported = ('feature_order', 'grad_and_hess')

    # Validate first so a misconfiguration is reported with a clear message
    # before any trainer class is selected or protection module is imported.
    if protect_object and protect_object not in supported:
        raise ValueError(
            f'Unsupported config.vertical.protect_object: '
            f'{protect_object!r}; expected an empty value or one of '
            f'{supported}')

    if config.model.type.lower() == 'random_forest':
        trainer_cls = RandomForestTrainer
    else:
        trainer_cls = VerticalTrainer

    # Every creation path receives the same keyword arguments.
    trainer_kwargs = dict(model=model,
                          data=data,
                          device=device,
                          config=config,
                          monitor=monitor)

    if not protect_object:
        return trainer_cls(**trainer_kwargs)

    # The protection factories are imported only when they are requested.
    if protect_object == 'feature_order':
        from federatedscope.vertical_fl.tree_based_models.trainer import \
            createFeatureOrderProtectedTrainer
        return createFeatureOrderProtectedTrainer(cls=trainer_cls,
                                                  **trainer_kwargs)

    # Only 'grad_and_hess' can reach this point after the validation above.
    from federatedscope.vertical_fl.tree_based_models.trainer import \
        createLabelProtectedTrainer
    return createLabelProtectedTrainer(cls=trainer_cls, **trainer_kwargs)
|
|
|
|
|
|
def bucketize(feature_order, bucket_size, bucket_num):
    """
    Split ``feature_order`` into consecutive buckets.

    Args:
        feature_order: Sequence supporting ``len()`` and slicing.
        bucket_size: Either a single integer (Python ``int`` or NumPy
            integer scalar) giving the nominal size of every bucket, or an
            iterable of per-bucket sizes that is used as given.
        bucket_num: Number of buckets; only used when ``bucket_size`` is a
            single integer.

    Returns:
        list: Consecutive slices of ``feature_order``, one per bucket.

    Note:
        With an integer ``bucket_size``, the difference between
        ``len(feature_order)`` and ``bucket_size * bucket_num`` is spread
        over distinct, randomly chosen buckets, each growing or shrinking
        by one element. The choice uses ``np.random.choice`` with
        ``replace=False``, so that difference must not exceed
        ``bucket_num`` in absolute value, otherwise NumPy raises
        ``ValueError``.
    """
    # NumPy integer scalars are not instances of ``int``; accept them too so
    # a size computed with NumPy is not mistaken for a list of sizes.
    if isinstance(bucket_size, (int, np.integer)):
        base_size = int(bucket_size)
        remainder = len(feature_order) - base_size * bucket_num
        sizes = [base_size for _ in range(bucket_num)]
        if remainder != 0:
            # +1 per chosen bucket when elements are left over,
            # -1 per chosen bucket when there are too few elements.
            step = 1 if remainder > 0 else -1
            chosen = np.random.choice(a=bucket_num,
                                      size=abs(remainder),
                                      replace=False)
            for idx in chosen:
                sizes[idx] += step
    else:
        sizes = bucket_size

    buckets = []
    start = 0
    for size in sizes:
        end = start + size
        buckets.append(feature_order[start:end])
        start = end
    return buckets
|