from federatedscope.vertical_fl.tree_based_models.trainer \ import VerticalTrainer, RandomForestTrainer import numpy as np def get_vertical_trainer(config, model, data, device, monitor): if config.model.type.lower() == 'random_forest': trainer_cls = RandomForestTrainer else: trainer_cls = VerticalTrainer protect_object = config.vertical.protect_object if not protect_object or protect_object == '': return trainer_cls(model=model, data=data, device=device, config=config, monitor=monitor) elif protect_object == 'feature_order': from federatedscope.vertical_fl.tree_based_models.trainer import \ createFeatureOrderProtectedTrainer return createFeatureOrderProtectedTrainer(cls=trainer_cls, model=model, data=data, device=device, config=config, monitor=monitor) elif protect_object in ['grad_and_hess']: from federatedscope.vertical_fl.tree_based_models.trainer import \ createLabelProtectedTrainer return createLabelProtectedTrainer(cls=trainer_cls, model=model, data=data, device=device, config=config, monitor=monitor) else: raise ValueError def bucketize(feature_order, bucket_size, bucket_num): if isinstance(bucket_size, int): remainder = len(feature_order) - bucket_size * bucket_num bucket_size = [bucket_size for _ in range(bucket_num)] if remainder > 0: selected_idx = np.random.choice(a=bucket_num, size=remainder, replace=False) for each in selected_idx: bucket_size[each] += 1 elif remainder < 0: selected_idx = np.random.choice(a=bucket_num, size=-remainder, replace=False) for each in selected_idx: bucket_size[each] -= 1 bucketized_feature_order = list() start = 0 for each_bucket_size in bucket_size: end = start + each_bucket_size bucketized_feature_order.append(feature_order[start:end]) start = end return bucketized_feature_order