# FS-TFP/federatedscope/vertical_fl/tree_based_models/trainer/trainer.py

import numpy as np
import logging
from collections import deque

from federatedscope.vertical_fl.dataloader.utils import VerticalDataSampler
from federatedscope.vertical_fl.loss.utils import get_vertical_loss

logger = logging.getLogger(__name__)


class VerticalTrainer(object):
    def __init__(self, model, data, device, config, monitor):
        self.model = model
        self.data = data
        self.device = device
        self.cfg = config
        self.monitor = monitor

        self.merged_feature_order = None
        self.client_feature_order = None
        self.complete_feature_order_info = None
        self.client_feature_num = list()
        self.extra_info = None
        self.client_extra_info = None
        self.batch_x = None
        self.batch_y = None
        self.batch_y_hat = None
        self.batch_z = 0

    def _init_for_train(self):
        self.eta = self.cfg.train.optimizer.eta
        self.dataloader = VerticalDataSampler(
            data=self.data['train'],
            use_full_trainset=True,
            feature_frac=self.cfg.vertical.feature_subsample_ratio)
        self.criterion = get_vertical_loss(loss_type=self.cfg.criterion.type,
                                           model_type=self.cfg.model.type)

    def prepare_for_train(self):
        if self.dataloader.use_full_trainset:
            complete_feature_order_info = self._get_feature_order_info(
                self.data['train']['x'])
            self.complete_feature_order_info = complete_feature_order_info
        else:
            self.complete_feature_order_info = None

    def fetch_train_data(self, index=None):
        # Clear the variables from the last training round
        self.client_feature_num.clear()

        # Fetch new data
        batch_index, self.batch_x, self.batch_y = self.dataloader.sample_data(
            sample_size=self.cfg.dataloader.batch_size, index=index)
        feature_index, self.batch_x = self.dataloader.sample_feature(
            self.batch_x)
        # Convert 'range' to 'list' to support gRPC protocols
        # in distributed mode
        batch_index = list(batch_index)

        # If the complete trainset is used, we only need to take the slices
        # of the complete feature order info according to the feature index,
        # rather than re-ordering the instances
        if self.dataloader.use_full_trainset:
            assert self.complete_feature_order_info is not None
            feature_order_info = dict()
            for key in self.complete_feature_order_info:
                if isinstance(self.complete_feature_order_info[key],
                              (list, np.ndarray)):
                    feature_order_info[key] = [
                        self.complete_feature_order_info[key][_index]
                        for _index in feature_index
                    ]
                else:
                    feature_order_info[key] = \
                        self.complete_feature_order_info[key]
        else:
            feature_order_info = self._get_feature_order_info(self.batch_x)

        if 'raw_feature_order' in feature_order_info:
            # When a protection method is applied, the raw (real) feature
            # order might be different from the shared feature order
            self.client_feature_order = feature_order_info['raw_feature_order']
            feature_order_info.pop('raw_feature_order')
        else:
            self.client_feature_order = feature_order_info['feature_order']
        self.client_extra_info = feature_order_info.get('extra_info', None)

        return batch_index, feature_order_info

    def train(self, training_info=None, tree_num=0, node_num=None):
        # Start to build a tree
        if node_num is None:
            if training_info is not None and \
                    self.cfg.vertical.mode == 'feature_gathering':
                self.merged_feature_order, self.extra_info = \
                    self._parse_training_info(training_info)
            return self._compute_for_root(tree_num=tree_num)
        # Continue training
        else:
            return self._compute_for_node(tree_num, node_num)

    def get_abs_feature_idx(self, rel_feature_idx):
        # Map an index within the subsampled feature set back to the
        # absolute index in the original feature space
        if self.dataloader.selected_feature_index is None:
            return rel_feature_idx
        else:
            return self.dataloader.selected_feature_index[rel_feature_idx]

    def get_feature_value(self, feature_idx, value_idx):
        assert self.batch_x is not None

        instance_idx = self.client_feature_order[feature_idx][value_idx]
        return self.batch_x[instance_idx, feature_idx]
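
    # Illustrative sketch (editor's addition, not part of the original file):
    # `client_feature_order[f]` is an ascending argsort of feature f, so
    # `get_feature_value(f, v)` returns the (v+1)-th smallest value of that
    # feature in the current batch. Assuming a trainer instance whose
    # batch_x == np.array([[3.], [1.], [2.]]) and hence
    # client_feature_order == [np.array([1, 2, 0])], calling
    # get_feature_value(feature_idx=0, value_idx=0) returns 1.0, the
    # minimum of the column.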

    def _predict(self, tree_num):
        self._compute_weight(tree_num, node_num=0)

    def _parse_training_info(self, feature_order_info):
        client_ids = sorted(feature_order_info.keys())
        merged_feature_order = list()
        for each_client in client_ids:
            _feature_order = feature_order_info[each_client]['feature_order']
            merged_feature_order.append(_feature_order)
            self.client_feature_num.append(len(_feature_order))
        merged_feature_order = np.concatenate(merged_feature_order)

        # TODO: different extra_info for different clients
        extra_info = feature_order_info[client_ids[0]].get('extra_info', None)
        if extra_info is not None:
            merged_extra_info = dict()
            for each_key in extra_info.keys():
                merged_extra_info[each_key] = np.concatenate([
                    feature_order_info[idx]['extra_info'][each_key]
                    for idx in client_ids
                ])
        else:
            merged_extra_info = None

        return merged_feature_order, merged_extra_info
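
    # Illustrative note (editor's addition): with two clients where client 1
    # holds one feature and client 2 holds two, e.g.
    #
    #     feature_order_info = {
    #         1: {'feature_order': [np.array([1, 2, 0])]},
    #         2: {'feature_order': [np.array([0, 2, 1]), np.array([2, 0, 1])]},
    #     }
    #
    # the clients are visited in sorted-id order, yielding a merged order with
    # rows [[1, 2, 0], [0, 2, 1], [2, 0, 1]] and client_feature_num == [1, 2],
    # so a row index in merged_feature_order can later be mapped back to the
    # owning client.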

    def _get_feature_order_info(self, data):
        # For each feature, argsort the column to get the instance indices
        # sorted by ascending feature value
        num_of_feature = data.shape[1]
        feature_order = [0] * num_of_feature
        for i in range(num_of_feature):
            feature_order[i] = data[:, i].argsort()
        return {'feature_order': feature_order}
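
    # Since `_get_feature_order_info` does not touch `self`, its behaviour
    # can be checked directly (editor's sketch, not part of the original
    # file):
    #
    #     >>> x = np.array([[3., 10.], [1., 30.], [2., 20.]])
    #     >>> info = VerticalTrainer._get_feature_order_info(None, x)
    #     >>> [col.tolist() for col in info['feature_order']]
    #     [[1, 2, 0], [0, 2, 1]]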

    def _get_ordered_gh(self,
                        tree_num,
                        node_num,
                        feature_idx,
                        grad=None,
                        hess=None,
                        indicator=None,
                        label=None):
        order = self.merged_feature_order[feature_idx]
        if grad is not None:
            ordered_g = np.asarray(grad)[order]
        elif self.model[tree_num][node_num].grad is not None:
            ordered_g = self.model[tree_num][node_num].grad[order]
        else:
            ordered_g = None
        if hess is not None:
            ordered_h = np.asarray(hess)[order]
        elif self.model[tree_num][node_num].hess is not None:
            ordered_h = self.model[tree_num][node_num].hess[order]
        else:
            ordered_h = None
        if indicator is not None:
            ordered_indicator = np.asarray(indicator)[order]
        elif self.model[tree_num][node_num].indicator is not None:
            ordered_indicator = self.model[tree_num][node_num].indicator[order]
        else:
            ordered_indicator = None
        if label is not None:
            ordered_label = np.asarray(label)[order]
        elif self.model[tree_num][node_num].label is not None:
            ordered_label = self.model[tree_num][node_num].label[order]
        else:
            ordered_label = None

        return ordered_g, ordered_h, ordered_indicator, ordered_label
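
    # Illustrative note (editor's addition): re-indexing by `order` arranges
    # the statistics by ascending feature value, so prefix sums over
    # `ordered_g`/`ordered_h` directly give the left-child statistics of a
    # candidate split. E.g. with order = [1, 2, 0] and grad = [0.1, 0.2, 0.3],
    # ordered_g == [0.2, 0.3, 0.1].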

    def _get_best_gain(self,
                       tree_num,
                       node_num,
                       grad=None,
                       hess=None,
                       indicator=None):
        best_gain = 0
        split_ref = {'feature_idx': None, 'value_idx': None}

        if self.merged_feature_order is None:
            self.merged_feature_order = self.client_feature_order
        if self.extra_info is None:
            self.extra_info = self.client_extra_info
        feature_num = len(self.merged_feature_order)

        split_position = None
        if self.extra_info is not None:
            split_position = self.extra_info.get('split_position', None)
        if self.model[tree_num][node_num].indicator is not None:
            activate_idx = [
                np.nonzero(self.model[tree_num][node_num].indicator[order])[0]
                for order in self.merged_feature_order
            ]
        else:
            activate_idx = [
                np.arange(self.batch_x.shape[0])
                for _ in self.merged_feature_order
            ]
        activate_idx = np.asarray(activate_idx)
        if split_position is None:
            # The left/right sub-tree cannot be empty
            split_position = activate_idx[:, 1:]

        for feature_idx in range(feature_num):
            ordered_g, ordered_h, ordered_indicator, ordered_label = \
                self._get_ordered_gh(tree_num,
                                     node_num,
                                     feature_idx,
                                     grad,
                                     hess,
                                     indicator,
                                     label=None)
            order = self.merged_feature_order[feature_idx]
            for value_idx in split_position[feature_idx]:
                if self.model[tree_num].check_empty_child(
                        node_num, value_idx, order):
                    continue
                gain = self.model[tree_num].cal_gain(ordered_g, ordered_h,
                                                     value_idx,
                                                     ordered_indicator)
                if gain > best_gain:
                    best_gain = gain
                    split_ref['feature_idx'] = feature_idx
                    split_ref['value_idx'] = value_idx

        return best_gain > 0, split_ref, best_gain
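
    # Hedged note (editor's addition): the gain computation is delegated to
    # the model's `cal_gain`; for XGBoost-style trees this is typically,
    # up to constant factors and the regularisation term gamma,
    #
    #     gain = G_L^2 / (H_L + lambda) + G_R^2 / (H_R + lambda)
    #            - (G_L + G_R)^2 / (H_L + H_R + lambda)
    #
    # where G/H are the summed gradients/hessians of the left/right child
    # induced by splitting at `value_idx` in the sorted order. The exact
    # formula lives in the model implementation, not in this trainer.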

    def _compute_for_root(self, tree_num):
        if self.batch_y_hat is None:
            # Assign random predictions for the first tree (tree_num = 0)
            self.batch_y_hat = [
                np.random.uniform(low=0.0, high=1.0, size=len(self.batch_y))
            ]
        g, h = self.criterion.get_grad_and_hess(self.batch_y, self.batch_y_hat)
        node_num = 0
        self.model[tree_num][node_num].grad = g
        self.model[tree_num][node_num].hess = h
        self.model[tree_num][node_num].indicator = np.ones(len(self.batch_y))
        return self._compute_for_node(tree_num, node_num=node_num)
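
    # Hedged note (editor's addition): the exact grad/hess depend on the
    # configured criterion. For a squared-error loss, for example,
    # g = y_hat - y and h = 1 elementwise; the all-ones indicator simply
    # routes every batch instance to the root node.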

    def _compute_for_node(self, tree_num, node_num):
        # All the nodes have been traversed
        if node_num >= 2**self.model.max_depth - 1:
            self._predict(tree_num)
            return 'train_finish', None
        elif self.model[tree_num][node_num].status == 'off':
            return self._compute_for_node(tree_num, node_num + 1)
        # The leaf node
        elif node_num >= 2**(self.model.max_depth - 1) - 1:
            self._set_weight_and_status(tree_num, node_num)
            return self._compute_for_node(tree_num, node_num + 1)
        # Calculate the best gain
        else:
            if self.cfg.vertical.mode == 'feature_gathering':
                improved_flag, split_ref, _ = self._get_best_gain(
                    tree_num, node_num)
                if improved_flag:
                    split_feature = self.merged_feature_order[
                        split_ref['feature_idx']]
                    left_child, right_child = self.get_children_indicator(
                        value_idx=split_ref['value_idx'],
                        split_feature=split_feature)
                    self.update_child(tree_num, node_num, left_child,
                                      right_child)
                    results = (split_ref, tree_num, node_num)
                    return 'call_for_node_split', results
                else:
                    self._set_weight_and_status(tree_num, node_num)
                    return self._compute_for_node(tree_num, node_num + 1)
            elif self.cfg.vertical.mode == 'label_scattering':
                results = (self.model[tree_num][node_num].grad,
                           self.model[tree_num][node_num].hess,
                           self.model[tree_num][node_num].indicator, tree_num,
                           node_num)
                return 'call_for_local_gain', results

    def _compute_weight(self, tree_num, node_num):
        if node_num >= 2**self.model.max_depth - 1:
            if tree_num == 0:
                self.batch_y_hat = [self.batch_z]
            else:
                self.batch_y_hat.append(self.batch_z)
            self.batch_z = 0
        else:
            if self.model[tree_num][node_num].weight:
                self.batch_z += self.model[tree_num][
                    node_num].weight * self.model[tree_num][
                        node_num].indicator * self.eta
            self._compute_weight(tree_num, node_num + 1)
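
    # Illustrative note (editor's addition): for each finished tree,
    # `batch_z` accumulates eta * sum_over_nodes(weight * indicator), i.e.
    # the learning-rate-scaled leaf weight of whichever leaf each instance
    # falls into. `batch_y_hat` then keeps one such contribution per tree,
    # so the ensemble prediction corresponds to sum(batch_y_hat),
    # presumably aggregated inside the criterion.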

    def _set_weight_and_status(self, tree_num, node_num):
        self.model[tree_num].set_weight(node_num)

        # Turn off the whole sub-tree rooted at node_num via a BFS traversal
        queue = deque()
        queue.append(node_num)
        while len(queue) > 0:
            cur_node = queue.popleft()
            self.model[tree_num].set_status(cur_node, status='off')
            if 2 * cur_node + 2 <= 2**self.model[tree_num].max_depth - 1:
                queue.append(2 * cur_node + 1)
                queue.append(2 * cur_node + 2)
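
    # Illustrative note (editor's addition): with max_depth = 3 (nodes
    # numbered 0..6 in level order), calling `_set_weight_and_status` on
    # node 1 turns off nodes 1, 3 and 4, i.e. the whole sub-tree rooted at
    # the node that has just become a leaf.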

    def get_children_indicator(self, value_idx, split_feature):
        left_child = np.zeros(self.batch_x.shape[0])
        for x in range(value_idx):
            left_child[split_feature[x]] = 1
        right_child = np.ones(self.batch_x.shape[0]) - left_child
        return left_child, right_child
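
    # Illustrative note (editor's addition): `split_feature` is the sorted
    # instance order of the chosen feature, so the first `value_idx` entries
    # go to the left child. E.g. with split_feature = [1, 2, 0] and
    # value_idx = 2 on a batch of three instances,
    # left_child == [0., 1., 1.] and right_child == [1., 0., 0.].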

    def update_child(self, tree_num, node_num, left_child, right_child):
        self.model[tree_num].update_child(node_num, left_child, right_child)

    def get_best_gain_from_msg(self, msg, tree_num=None, node_num=None):
        client_has_max_gain = None
        max_gain = None
        for client_id, local_gain in msg.items():
            gain, improved_flag, _ = local_gain
            if improved_flag:
                if max_gain is None or gain > max_gain:
                    max_gain = gain
                    client_has_max_gain = client_id
        return max_gain, client_has_max_gain, None
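

# Illustrative usage sketch (editor's addition, not part of the original
# file): `get_best_gain_from_msg` only inspects the message payload, where
# each entry maps a client id to a `(gain, improved_flag, _)` tuple, so its
# behaviour can be demonstrated without a fully initialised trainer:
#
#     >>> msg = {1: (0.8, True, None),
#     ...        2: (1.2, True, None),
#     ...        3: (2.0, False, None)}
#     >>> VerticalTrainer.get_best_gain_from_msg(None, msg)
#     (1.2, 2, None)
#
# Client 3 reports the largest raw gain but is skipped because its
# improved_flag is False, so client 2 wins with a gain of 1.2.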