import numpy as np
import logging
import copy

from federatedscope.core.workers import Client
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import \
    abstract_paillier

logger = logging.getLogger(__name__)


class TreeClient(Client):
    """Client base class for vertical-FL tree models.

    Message flow handled by this class:

    * ``model_para``: every client calls ``trainer.prepare_for_train()``;
      the label owner (the client whose ``data['train']`` contains ``'y'``)
      then samples a batch and sends it to its peers as ``data_sample``.
    * ``data_sample``: a receiving client fetches its training data for the
      sampled indices and replies to the sender with ``training_info``.
    * ``training_info``: replies are buffered per sender; once
      ``client_num`` entries (including the label owner's own) are present,
      :meth:`train` is invoked with a snapshot of the buffer.

    Subclasses must implement :meth:`train` and :meth:`eval`.
    """

    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        super(TreeClient,
              self).__init__(ID, server_id, state, config, data, model, device,
                             strategy, *args, **kwargs)

        self.data = data
        # Only the label owner's training split contains the 'y' key.
        self.own_label = ('y' in data['train'])
        # Messages received from peers, per phase, keyed by sender ID.
        self.msg_buffer = {'train': {}, 'eval': {}}
        self.client_num = self._cfg.federate.client_num

        if self._cfg.vertical.eval_protection == 'he':
            # Paillier keypair, only created when evaluation is protected
            # with homomorphic encryption ('he').
            keys = abstract_paillier.generate_paillier_keypair(
                n_length=self._cfg.vertical.key_size)
            self.public_key, self.private_key = keys

        self.feature_order = None
        self.merged_feature_order = None

        # vertical.dims is treated as cumulative feature boundaries: the
        # diff (with a leading 0) gives each client's feature count and the
        # last entry is the total.
        self.feature_partition = np.diff(self._cfg.vertical.dims, prepend=0)
        self.total_num_of_feature = self._cfg.vertical.dims[-1]
        # Indexing with ID - 1 assumes client IDs start at 1.
        self.num_of_feature = self.feature_partition[self.ID - 1]
        self.feature_importance = [0] * self.num_of_feature

        self._init_data_related_var()

        self.register_handlers('model_para', self.callback_func_for_model_para)
        self.register_handlers('data_sample',
                               self.callback_func_for_data_sample)
        self.register_handlers('training_info',
                               self.callback_func_for_training_info)
        self.register_handlers('finish', self.callback_func_for_finish)

    def train(self, tree_num, node_num=None, training_info=None):
        """Train on the collected ``training_info``; must be overridden."""
        raise NotImplementedError

    def eval(self, tree_num):
        """Evaluate the model; must be overridden."""
        raise NotImplementedError

    def _init_data_related_var(self):
        """Initialize the trainer for training and reset cached test data."""
        self.trainer._init_for_train()
        self.test_x = None
        self.test_y = None

    def _build_training_info(self, feature_order_info):
        """Return what this client contributes as its ``training_info``.

        Shared by the label owner (for its own buffer entry) and by the
        other clients (for their reply), so both reject an unsupported
        ``cfg.vertical.mode`` the same way.

        Args:
            feature_order_info: the info object returned by
                ``trainer.fetch_train_data``.

        Returns:
            ``feature_order_info`` in ``'feature_gathering'`` mode, the
            placeholder string ``'dummy_info'`` in ``'label_scattering'``
            mode.

        Raises:
            TypeError: if ``cfg.vertical.mode`` is neither of the above.
        """
        mode = self._cfg.vertical.mode
        if mode == 'feature_gathering':
            return feature_order_info
        if mode == 'label_scattering':
            return 'dummy_info'
        raise TypeError(
            f'The expected types of vertical.mode include '
            f'["label_scattering", "feature_gathering"], but got '
            f'{self._cfg.vertical.mode}.')

    def callback_func_for_model_para(self, message: Message):
        """Handle ``model_para``: prepare for training and, on the label
        owner, start the first training round (tree 0).

        NOTE(review): the original comments say this step initializes a
        tree list of ``num_of_trees`` trees and ``y_hat`` on the label
        owner; presumably that happens inside
        ``trainer.prepare_for_train()`` -- confirm in the trainer.
        """
        self.state = message.state
        self.trainer.prepare_for_train()
        if self.own_label:
            batch_index, feature_order_info = self.trainer.fetch_train_data()
            self.start_a_new_training_round(batch_index,
                                            feature_order_info,
                                            tree_num=0)

    def callback_func_for_data_sample(self, message: Message):
        """Handle ``data_sample`` sent by the label owner.

        Fetches this client's training data for the received batch
        indices, records its feature order, and replies to the sender with
        a ``training_info`` message.

        Raises:
            TypeError: if ``cfg.vertical.mode`` is not supported.
        """
        self.state = message.state
        batch_index, sender = message.content, message.sender
        _, feature_order_info = self.trainer.fetch_train_data(
            index=batch_index)
        self.feature_order = feature_order_info['feature_order']
        training_info = self._build_training_info(feature_order_info)
        self.comm_manager.send(
            Message(msg_type='training_info',
                    sender=self.ID,
                    state=self.state,
                    receiver=[sender],
                    content=training_info))

    def callback_func_for_training_info(self, message: Message):
        """Buffer a peer's ``training_info`` and train once all arrived."""
        feature_order_info, sender = message.content, message.sender
        self.msg_buffer['train'][sender] = feature_order_info
        self.check_and_move_on()

    def callback_func_for_finish(self, message: Message):
        """Handle ``finish`` by logging that it was received."""
        logger.info(
            f"================= client {self.ID} received finish message "
            f"=================")
        # TODO(review): `self._monitor.finish_fl()` was left commented out
        # here; confirm whether monitor finalization should run on finish.

    def start_a_new_training_round(self,
                                   batch_index,
                                   feature_order_info,
                                   tree_num=0):
        """Start training round ``tree_num`` (called on the label owner).

        Resets the training buffer, stores this client's own contribution
        under its ID, and sends ``batch_index`` as ``data_sample`` to every
        neighbor except the server.

        Args:
            batch_index: sampled indices forwarded to the other clients.
            feature_order_info: mapping containing a ``'feature_order'``
                entry, as returned by ``trainer.fetch_train_data``.
            tree_num: index of the tree to train; stored as ``self.state``.

        Raises:
            TypeError: if ``cfg.vertical.mode`` is not supported. Before this
                change an unknown mode was silently treated as
                ``'label_scattering'`` here while peers raised.
        """
        self.msg_buffer['train'].clear()
        self.feature_order = feature_order_info['feature_order']
        self.msg_buffer['train'][self.ID] = self._build_training_info(
            feature_order_info)
        self.state = tree_num
        receiver = [
            each for each in self.comm_manager.neighbors.keys()
            if each != self.server_id
        ]
        send_message = Message(msg_type='data_sample',
                               sender=self.ID,
                               state=self.state,
                               receiver=receiver,
                               content=batch_index)
        self.comm_manager.send(send_message)

    def check_and_move_on(self):
        """Call :meth:`train` once every client's entry is buffered.

        When the training buffer holds ``client_num`` entries, a deep copy
        is handed to ``train`` so it is independent of the buffer, which is
        then cleared for the next round.
        """
        if len(self.msg_buffer['train']) == self.client_num:
            received_training_infos = copy.deepcopy(self.msg_buffer['train'])
            self.msg_buffer['train'].clear()
            self.train(tree_num=self.state,
                       training_info=received_training_infos)