import numpy as np
import logging

from collections import deque

from federatedscope.vertical_fl.dataloader.utils import VerticalDataSampler
from federatedscope.vertical_fl.loss.utils import get_vertical_loss

logger = logging.getLogger(__name__)


class VerticalTrainer(object):
    def __init__(self, model, data, device, config, monitor):
        self.model = model
        self.data = data
        self.device = device
        self.cfg = config
        self.monitor = monitor

        self.merged_feature_order = None
        self.client_feature_order = None
        self.complete_feature_order_info = None
        self.client_feature_num = list()
        self.extra_info = None
        self.client_extra_info = None
        self.batch_x = None
        self.batch_y = None
        self.batch_y_hat = None
        self.batch_z = 0

    def _init_for_train(self):
        self.eta = self.cfg.train.optimizer.eta
        self.dataloader = VerticalDataSampler(
            data=self.data['train'],
            use_full_trainset=True,
            feature_frac=self.cfg.vertical.feature_subsample_ratio)
        self.criterion = get_vertical_loss(loss_type=self.cfg.criterion.type,
                                           model_type=self.cfg.model.type)

    def prepare_for_train(self):
        if self.dataloader.use_full_trainset:
            complete_feature_order_info = self._get_feature_order_info(
                self.data['train']['x'])
            self.complete_feature_order_info = complete_feature_order_info
        else:
            self.complete_feature_order_info = None

    def fetch_train_data(self, index=None):
        # Clear the variables from the last training round
        self.client_feature_num.clear()

        # Fetch new data
        batch_index, self.batch_x, self.batch_y = self.dataloader.sample_data(
            sample_size=self.cfg.dataloader.batch_size, index=index)
        feature_index, self.batch_x = self.dataloader.sample_feature(
            self.batch_x)
        # Convert 'range' to 'list' to support gRPC protocols
        # in distributed mode
        batch_index = list(batch_index)

        # If the complete trainset is used, we only need to take slices of
        # the complete feature order info according to the feature index,
        # rather than re-ordering the instances
        if self.dataloader.use_full_trainset:
            assert self.complete_feature_order_info is not None
            feature_order_info = dict()
            for key in self.complete_feature_order_info:
                if isinstance(self.complete_feature_order_info[key],
                              (list, np.ndarray)):
                    feature_order_info[key] = [
                        self.complete_feature_order_info[key][_index]
                        for _index in feature_index
                    ]
                else:
                    feature_order_info[key] = \
                        self.complete_feature_order_info[key]
        else:
            feature_order_info = self._get_feature_order_info(self.batch_x)

        if 'raw_feature_order' in feature_order_info:
            # When applying a protection method, the raw (real) feature
            # order might differ from the shared feature order
            self.client_feature_order = \
                feature_order_info['raw_feature_order']
            feature_order_info.pop('raw_feature_order')
        else:
            self.client_feature_order = feature_order_info['feature_order']
        self.client_extra_info = feature_order_info.get('extra_info', None)

        return batch_index, feature_order_info

    def train(self, training_info=None, tree_num=0, node_num=None):
        # Start to build a tree
        if node_num is None:
            if training_info is not None and \
                    self.cfg.vertical.mode == 'feature_gathering':
                self.merged_feature_order, self.extra_info = \
                    self._parse_training_info(training_info)
            return self._compute_for_root(tree_num=tree_num)
        # Continue training
        else:
            return self._compute_for_node(tree_num, node_num)

    def get_abs_feature_idx(self, rel_feature_idx):
        if self.dataloader.selected_feature_index is None:
            return rel_feature_idx
        else:
            return self.dataloader.selected_feature_index[rel_feature_idx]
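
    # Note (illustrative, hypothetical values): with feature subsampling
    # enabled and selected_feature_index == [0, 3, 4], the relative index 1
    # above maps to absolute feature index 3 in the client's full feature set.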

    def get_feature_value(self, feature_idx, value_idx):
        assert self.batch_x is not None

        instance_idx = self.client_feature_order[feature_idx][value_idx]
        return self.batch_x[instance_idx, feature_idx]

    def _predict(self, tree_num):
        self._compute_weight(tree_num, node_num=0)

    def _parse_training_info(self, feature_order_info):
        client_ids = sorted(feature_order_info.keys())
        merged_feature_order = list()
        for each_client in client_ids:
            _feature_order = feature_order_info[each_client]['feature_order']
            merged_feature_order.append(_feature_order)
            self.client_feature_num.append(len(_feature_order))
        merged_feature_order = np.concatenate(merged_feature_order)

        # TODO: different extra_info for different clients
        extra_info = feature_order_info[client_ids[0]].get('extra_info', None)
        if extra_info is not None:
            merged_extra_info = dict()
            for each_key in extra_info.keys():
                merged_extra_info[each_key] = np.concatenate([
                    feature_order_info[idx]['extra_info'][each_key]
                    for idx in client_ids
                ])
        else:
            merged_extra_info = None

        return merged_feature_order, merged_extra_info

    def _get_feature_order_info(self, data):
        num_of_feature = data.shape[1]
        feature_order = [0] * num_of_feature
        for i in range(num_of_feature):
            feature_order[i] = data[:, i].argsort()
        return {'feature_order': feature_order}

    def _get_ordered_gh(self,
                        tree_num,
                        node_num,
                        feature_idx,
                        grad=None,
                        hess=None,
                        indicator=None,
                        label=None):
        order = self.merged_feature_order[feature_idx]
        if grad is not None:
            ordered_g = np.asarray(grad)[order]
        elif self.model[tree_num][node_num].grad is not None:
            ordered_g = self.model[tree_num][node_num].grad[order]
        else:
            ordered_g = None
        if hess is not None:
            ordered_h = np.asarray(hess)[order]
        elif self.model[tree_num][node_num].hess is not None:
            ordered_h = self.model[tree_num][node_num].hess[order]
        else:
            ordered_h = None
        if indicator is not None:
            ordered_indicator = np.asarray(indicator)[order]
        elif self.model[tree_num][node_num].indicator is not None:
            ordered_indicator = self.model[tree_num][node_num].indicator[order]
        else:
            ordered_indicator = None
        if label is not None:
            ordered_label = np.asarray(label)[order]
        elif self.model[tree_num][node_num].label is not None:
            ordered_label = self.model[tree_num][node_num].label[order]
        else:
            ordered_label = None

        return ordered_g, ordered_h, ordered_indicator, ordered_label

    def _get_best_gain(self,
                       tree_num,
                       node_num,
                       grad=None,
                       hess=None,
                       indicator=None):
        best_gain = 0
        split_ref = {'feature_idx': None, 'value_idx': None}

        if self.merged_feature_order is None:
            self.merged_feature_order = self.client_feature_order
        if self.extra_info is None:
            self.extra_info = self.client_extra_info
        feature_num = len(self.merged_feature_order)
        split_position = None
        if self.extra_info is not None:
            split_position = self.extra_info.get('split_position', None)
        if self.model[tree_num][node_num].indicator is not None:
            activate_idx = [
                np.nonzero(self.model[tree_num][node_num].indicator[order])[0]
                for order in self.merged_feature_order
            ]
        else:
            activate_idx = [
                np.arange(self.batch_x.shape[0])
                for _ in self.merged_feature_order
            ]
        activate_idx = np.asarray(activate_idx)
        if split_position is None:
            # The left/right sub-tree cannot be empty
            split_position = activate_idx[:, 1:]

        for feature_idx in range(feature_num):
            ordered_g, ordered_h, ordered_indicator, ordered_label = \
                self._get_ordered_gh(tree_num, node_num, feature_idx, grad,
                                     hess, indicator, label=None)

            order = self.merged_feature_order[feature_idx]
            for value_idx in split_position[feature_idx]:
                if self.model[tree_num].check_empty_child(
                        node_num, value_idx, order):
                    continue
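                # Score this candidate split: cal_gain evaluates the split
                # gain at value_idx from the ordered gradient/hessian
                # statistics of the active instances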
                gain = self.model[tree_num].cal_gain(ordered_g, ordered_h,
                                                     value_idx,
                                                     ordered_indicator)

                if gain > best_gain:
                    best_gain = gain
                    split_ref['feature_idx'] = feature_idx
                    split_ref['value_idx'] = value_idx

        return best_gain > 0, split_ref, best_gain

    def _compute_for_root(self, tree_num):
        if self.batch_y_hat is None:
            # Assign random predictions when tree_num = 0
            self.batch_y_hat = [
                np.random.uniform(low=0.0, high=1.0, size=len(self.batch_y))
            ]
        g, h = self.criterion.get_grad_and_hess(self.batch_y,
                                                self.batch_y_hat)
        node_num = 0
        self.model[tree_num][node_num].grad = g
        self.model[tree_num][node_num].hess = h
        self.model[tree_num][node_num].indicator = np.ones(len(self.batch_y))
        return self._compute_for_node(tree_num, node_num=node_num)

    def _compute_for_node(self, tree_num, node_num):
        # All the nodes have been traversed
        if node_num >= 2**self.model.max_depth - 1:
            self._predict(tree_num)
            return 'train_finish', None
        elif self.model[tree_num][node_num].status == 'off':
            return self._compute_for_node(tree_num, node_num + 1)
        # The leaf node
        elif node_num >= 2**(self.model.max_depth - 1) - 1:
            self._set_weight_and_status(tree_num, node_num)
            return self._compute_for_node(tree_num, node_num + 1)
        # Calculate the best gain
        else:
            if self.cfg.vertical.mode == 'feature_gathering':
                improved_flag, split_ref, _ = self._get_best_gain(
                    tree_num, node_num)
                if improved_flag:
                    split_feature = self.merged_feature_order[
                        split_ref['feature_idx']]
                    left_child, right_child = self.get_children_indicator(
                        value_idx=split_ref['value_idx'],
                        split_feature=split_feature)
                    self.update_child(tree_num, node_num, left_child,
                                      right_child)
                    results = (split_ref, tree_num, node_num)
                    return 'call_for_node_split', results
                else:
                    self._set_weight_and_status(tree_num, node_num)
                    return self._compute_for_node(tree_num, node_num + 1)
            elif self.cfg.vertical.mode == 'label_scattering':
                results = (self.model[tree_num][node_num].grad,
                           self.model[tree_num][node_num].hess,
                           self.model[tree_num][node_num].indicator, tree_num,
                           node_num)
                return 'call_for_local_gain', results

    def _compute_weight(self, tree_num, node_num):
        if node_num >= 2**self.model.max_depth - 1:
            if tree_num == 0:
                self.batch_y_hat = [self.batch_z]
            else:
                self.batch_y_hat.append(self.batch_z)
            self.batch_z = 0
        else:
            if self.model[tree_num][node_num].weight:
                self.batch_z += self.model[tree_num][node_num].weight * \
                    self.model[tree_num][node_num].indicator * self.eta
            self._compute_weight(tree_num, node_num + 1)

    def _set_weight_and_status(self, tree_num, node_num):
        self.model[tree_num].set_weight(node_num)

        # Turn off the whole subtree rooted at node_num
        queue = deque()
        queue.append(node_num)
        while len(queue) > 0:
            cur_node = queue.popleft()
            self.model[tree_num].set_status(cur_node, status='off')
            if 2 * cur_node + 2 <= 2**self.model[tree_num].max_depth - 1:
                queue.append(2 * cur_node + 1)
                queue.append(2 * cur_node + 2)

    def get_children_indicator(self, value_idx, split_feature):
        left_child = np.zeros(self.batch_x.shape[0])
        for x in range(value_idx):
            left_child[split_feature[x]] = 1
        right_child = np.ones(self.batch_x.shape[0]) - left_child
        return left_child, right_child

    def update_child(self, tree_num, node_num, left_child, right_child):
        self.model[tree_num].update_child(node_num, left_child, right_child)

    def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
        client_has_max_gain = None
        max_gain = None
        for client_id, local_gain in msg.items():
            gain, improved_flag, _ = local_gain
            if improved_flag:
                if max_gain is None or gain > max_gain:
                    max_gain = gain
                    client_has_max_gain = client_id

        return max_gain, client_has_max_gain, None
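

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch (NOT part of the original module) of what
# `_get_feature_order_info` computes: one argsort per feature column, so that
# feature_order[j][k] is the row index holding the k-th smallest value of
# feature j. The toy matrix below is hypothetical.
if __name__ == '__main__':
    _toy_x = np.array([[3.0, 10.0],
                       [1.0, 30.0],
                       [2.0, 20.0]])
    _feature_order = [_toy_x[:, j].argsort() for j in range(_toy_x.shape[1])]
    # Expected output: [array([1, 2, 0]), array([0, 2, 1])]
    print(_feature_order)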