import logging
from collections import deque

import numpy as np

from federatedscope.vertical_fl.dataloader.utils import VerticalDataSampler
from federatedscope.vertical_fl.loss.utils import get_vertical_loss

logger = logging.getLogger(__name__)


class VerticalTrainer(object):
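    """Trainer for tree-based models in vertical federated learning,
    where each client holds a disjoint subset of the features.

    A minimal usage sketch (in practice the surrounding federated runner
    drives these calls; the exact call order below is an illustrative
    assumption):

        trainer = VerticalTrainer(model, data, device, config, monitor)
        trainer._init_for_train()
        trainer.prepare_for_train()
        batch_index, feature_order_info = trainer.fetch_train_data()
        msg_type, results = trainer.train()
    """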

    def __init__(self, model, data, device, config, monitor):
        self.model = model
        self.data = data
        self.device = device
        self.cfg = config
        self.monitor = monitor

        self.merged_feature_order = None
        self.client_feature_order = None
        self.complete_feature_order_info = None
        self.client_feature_num = list()
        self.extra_info = None
        self.client_extra_info = None
        self.batch_x = None
        self.batch_y = None
        self.batch_y_hat = None
        self.batch_z = 0

    def _init_for_train(self):
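        """Set up the learning rate, the data sampler, and the loss
        criterion before training starts."""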
        self.eta = self.cfg.train.optimizer.eta
        self.dataloader = VerticalDataSampler(
            data=self.data['train'],
            use_full_trainset=True,
            feature_frac=self.cfg.vertical.feature_subsample_ratio)
        self.criterion = get_vertical_loss(loss_type=self.cfg.criterion.type,
                                           model_type=self.cfg.model.type)

    def prepare_for_train(self):
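        """Pre-compute the feature order over the complete trainset so
        that each batch can reuse slices of it instead of re-sorting."""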
        if self.dataloader.use_full_trainset:
            complete_feature_order_info = self._get_feature_order_info(
                self.data['train']['x'])
            self.complete_feature_order_info = complete_feature_order_info
        else:
            self.complete_feature_order_info = None

    def fetch_train_data(self, index=None):
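        """Sample a batch of instances and a subset of features, and
        return the batch index together with the feature order info that
        can be shared with other participants."""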
        # Clear the variables from the last training round
        self.client_feature_num.clear()

        # Fetch new data
        batch_index, self.batch_x, self.batch_y = self.dataloader.sample_data(
            sample_size=self.cfg.dataloader.batch_size, index=index)
        feature_index, self.batch_x = self.dataloader.sample_feature(
            self.batch_x)

        # Convert 'range' to 'list' to support gRPC protocols
        # in distributed mode
        batch_index = list(batch_index)

        # If the complete trainset is used, we only need to take slices of
        # the complete feature order info according to the feature index,
        # rather than re-ordering the instances
        if self.dataloader.use_full_trainset:
            assert self.complete_feature_order_info is not None
            feature_order_info = dict()
            for key in self.complete_feature_order_info:
                if isinstance(self.complete_feature_order_info[key],
                              (list, np.ndarray)):
                    feature_order_info[key] = [
                        self.complete_feature_order_info[key][_index]
                        for _index in feature_index
                    ]
                else:
                    feature_order_info[key] = \
                        self.complete_feature_order_info[key]
        else:
            feature_order_info = self._get_feature_order_info(self.batch_x)

        if 'raw_feature_order' in feature_order_info:
            # When a protection method is applied, the raw (real) feature
            # order might differ from the shared feature order
            self.client_feature_order = feature_order_info['raw_feature_order']
            feature_order_info.pop('raw_feature_order')
        else:
            self.client_feature_order = feature_order_info['feature_order']
        self.client_extra_info = feature_order_info.get('extra_info', None)

        return batch_index, feature_order_info

    def train(self, training_info=None, tree_num=0, node_num=None):
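        """Run one training step: start building a new tree when
        ``node_num`` is None, otherwise continue from the given node."""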
        # Start to build a tree
        if node_num is None:
            if training_info is not None and \
                    self.cfg.vertical.mode == 'feature_gathering':
                self.merged_feature_order, self.extra_info = \
                    self._parse_training_info(training_info)
            return self._compute_for_root(tree_num=tree_num)
        # Continue training
        else:
            return self._compute_for_node(tree_num, node_num)

    def get_abs_feature_idx(self, rel_feature_idx):
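        """Map a feature index relative to the sampled feature subset
        back to its absolute index in the full feature set."""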
        if self.dataloader.selected_feature_index is None:
            return rel_feature_idx
        else:
            return self.dataloader.selected_feature_index[rel_feature_idx]

    def get_feature_value(self, feature_idx, value_idx):
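        """Return the value at position ``value_idx`` in the sorted order
        of feature ``feature_idx`` within the current batch."""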
        assert self.batch_x is not None

        instance_idx = self.client_feature_order[feature_idx][value_idx]
        return self.batch_x[instance_idx, feature_idx]

    def _predict(self, tree_num):
        self._compute_weight(tree_num, node_num=0)

    def _parse_training_info(self, feature_order_info):
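        """Merge the feature orders (and optional extra info) reported by
        all clients into flat arrays, ordered by client id."""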
        client_ids = sorted(feature_order_info.keys())
        merged_feature_order = list()
        for each_client in client_ids:
            _feature_order = feature_order_info[each_client]['feature_order']
            merged_feature_order.append(_feature_order)
            self.client_feature_num.append(len(_feature_order))
        merged_feature_order = np.concatenate(merged_feature_order)

        # TODO: different extra_info for different clients
        extra_info = feature_order_info[client_ids[0]].get('extra_info', None)
        if extra_info is not None:
            merged_extra_info = dict()
            for each_key in extra_info.keys():
                merged_extra_info[each_key] = np.concatenate([
                    feature_order_info[idx]['extra_info'][each_key]
                    for idx in client_ids
                ])
        else:
            merged_extra_info = None

        return merged_feature_order, merged_extra_info

    def _get_feature_order_info(self, data):
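        """Compute, for each feature, the order of the instances obtained
        by sorting the feature values (argsort)."""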
        num_of_feature = data.shape[1]
        feature_order = [0] * num_of_feature
        for i in range(num_of_feature):
            feature_order[i] = data[:, i].argsort()
        return {'feature_order': feature_order}

    def _get_ordered_gh(self,
                        tree_num,
                        node_num,
                        feature_idx,
                        grad=None,
                        hess=None,
                        indicator=None,
                        label=None):
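        """Re-order the grad/hess/indicator/label vectors of the given
        node by the sorted order of feature ``feature_idx``. Explicitly
        passed arguments take precedence over the values cached in the
        node."""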
        order = self.merged_feature_order[feature_idx]
        node = self.model[tree_num][node_num]

        def _reorder(given, cached):
            # Prefer the explicitly passed vector, fall back to the one
            # cached in the node, and keep None if neither is available
            if given is not None:
                return np.asarray(given)[order]
            elif cached is not None:
                return cached[order]
            else:
                return None

        ordered_g = _reorder(grad, node.grad)
        ordered_h = _reorder(hess, node.hess)
        ordered_indicator = _reorder(indicator, node.indicator)
        ordered_label = _reorder(label, node.label)

        return ordered_g, ordered_h, ordered_indicator, ordered_label

    def _get_best_gain(self,
                       tree_num,
                       node_num,
                       grad=None,
                       hess=None,
                       indicator=None):
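        """Scan all candidate (feature, split position) pairs and return
        whether any split improves the gain, a reference to the best
        split, and the best gain value."""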
        best_gain = 0
        split_ref = {'feature_idx': None, 'value_idx': None}

        if self.merged_feature_order is None:
            self.merged_feature_order = self.client_feature_order
        if self.extra_info is None:
            self.extra_info = self.client_extra_info

        feature_num = len(self.merged_feature_order)
        split_position = None
        if self.extra_info is not None:
            split_position = self.extra_info.get('split_position', None)

        if self.model[tree_num][node_num].indicator is not None:
            activate_idx = [
                np.nonzero(self.model[tree_num][node_num].indicator[order])[0]
                for order in self.merged_feature_order
            ]
        else:
            activate_idx = [
                np.arange(self.batch_x.shape[0])
                for _ in self.merged_feature_order
            ]

        activate_idx = np.asarray(activate_idx)
        if split_position is None:
            # The left/right sub-tree cannot be empty
            split_position = activate_idx[:, 1:]

        for feature_idx in range(feature_num):
            ordered_g, ordered_h, ordered_indicator, ordered_label = \
                self._get_ordered_gh(tree_num,
                                     node_num,
                                     feature_idx,
                                     grad,
                                     hess,
                                     indicator,
                                     label=None)
            order = self.merged_feature_order[feature_idx]
            for value_idx in split_position[feature_idx]:
                if self.model[tree_num].check_empty_child(
                        node_num, value_idx, order):
                    continue
                gain = self.model[tree_num].cal_gain(ordered_g, ordered_h,
                                                     value_idx,
                                                     ordered_indicator)

                if gain > best_gain:
                    best_gain = gain
                    split_ref['feature_idx'] = feature_idx
                    split_ref['value_idx'] = value_idx

        return best_gain > 0, split_ref, best_gain

    def _compute_for_root(self, tree_num):
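        """Initialize the root node of a new tree: derive the gradients
        and hessians from the current predictions (random predictions for
        the very first tree) and activate every instance in the batch."""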
        if self.batch_y_hat is None:
            # Assign random predictions when tree_num = 0
            self.batch_y_hat = [
                np.random.uniform(low=0.0, high=1.0, size=len(self.batch_y))
            ]
        g, h = self.criterion.get_grad_and_hess(self.batch_y,
                                                self.batch_y_hat)
        node_num = 0
        self.model[tree_num][node_num].grad = g
        self.model[tree_num][node_num].hess = h
        self.model[tree_num][node_num].indicator = np.ones(len(self.batch_y))
        return self._compute_for_node(tree_num, node_num=node_num)

    def _compute_for_node(self, tree_num, node_num):
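        """Traverse the nodes of the tree in index order: finish the tree
        once every node has been visited, skip nodes that are 'off', set
        leaf weights at the last level, and otherwise search for the best
        split (or ask the clients for local gains)."""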
        # All the nodes have been traversed
        if node_num >= 2**self.model.max_depth - 1:
            self._predict(tree_num)
            return 'train_finish', None
        elif self.model[tree_num][node_num].status == 'off':
            return self._compute_for_node(tree_num, node_num + 1)
        # The leaf node
        elif node_num >= 2**(self.model.max_depth - 1) - 1:
            self._set_weight_and_status(tree_num, node_num)
            return self._compute_for_node(tree_num, node_num + 1)
        # Calculate the best gain
        else:
            if self.cfg.vertical.mode == 'feature_gathering':
                improved_flag, split_ref, _ = self._get_best_gain(
                    tree_num, node_num)
                if improved_flag:
                    split_feature = self.merged_feature_order[
                        split_ref['feature_idx']]
                    left_child, right_child = self.get_children_indicator(
                        value_idx=split_ref['value_idx'],
                        split_feature=split_feature)
                    self.update_child(tree_num, node_num, left_child,
                                      right_child)
                    results = (split_ref, tree_num, node_num)
                    return 'call_for_node_split', results
                else:
                    self._set_weight_and_status(tree_num, node_num)
                    return self._compute_for_node(tree_num, node_num + 1)
            elif self.cfg.vertical.mode == 'label_scattering':
                results = (self.model[tree_num][node_num].grad,
                           self.model[tree_num][node_num].hess,
                           self.model[tree_num][node_num].indicator, tree_num,
                           node_num)
                return 'call_for_local_gain', results

    def _compute_weight(self, tree_num, node_num):
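        """Accumulate the weighted leaf outputs of the finished tree into
        ``batch_z``, then store the result as this tree's contribution to
        the running predictions ``batch_y_hat``."""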
        if node_num >= 2**self.model.max_depth - 1:
            if tree_num == 0:
                self.batch_y_hat = [self.batch_z]
            else:
                self.batch_y_hat.append(self.batch_z)
            self.batch_z = 0
        else:
            if self.model[tree_num][node_num].weight:
                self.batch_z += self.model[tree_num][
                    node_num].weight * self.model[tree_num][
                        node_num].indicator * self.eta
            self._compute_weight(tree_num, node_num + 1)

    def _set_weight_and_status(self, tree_num, node_num):
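        """Turn a node into a leaf: set its weight and mark the node and
        its entire subtree as 'off' via a breadth-first traversal."""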
        self.model[tree_num].set_weight(node_num)

        queue = deque()
        queue.append(node_num)
        while len(queue) > 0:
            cur_node = queue.popleft()
            self.model[tree_num].set_status(cur_node, status='off')
            if 2 * cur_node + 2 <= 2**self.model[tree_num].max_depth - 1:
                queue.append(2 * cur_node + 1)
                queue.append(2 * cur_node + 2)

    def get_children_indicator(self, value_idx, split_feature):
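        """Build 0/1 indicator vectors for the two children of a split:
        the instances at the first ``value_idx`` positions in the sorted
        order of the split feature go to the left child, the remaining
        ones to the right child."""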
        left_child = np.zeros(self.batch_x.shape[0])
        for x in range(value_idx):
            left_child[split_feature[x]] = 1
        right_child = np.ones(self.batch_x.shape[0]) - left_child

        return left_child, right_child

    def update_child(self, tree_num, node_num, left_child, right_child):
        self.model[tree_num].update_child(node_num, left_child, right_child)

    def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
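        """Pick the client reporting the largest improved local gain from
        the gathered messages (presumably sent by the clients in the
        label_scattering mode)."""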
        client_has_max_gain = None
        max_gain = None
        for client_id, local_gain in msg.items():
            gain, improved_flag, _ = local_gain
            if improved_flag:
                if max_gain is None or gain > max_gain:
                    max_gain = gain
                    client_has_max_gain = client_id

        return max_gain, client_has_max_gain, None