FS-TFP/federatedscope/vertical_fl/tree_based_models/worker/TreeClient.py

145 lines
5.4 KiB
Python

import numpy as np
import logging
import copy
from federatedscope.core.workers import Client
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import \
abstract_paillier
logger = logging.getLogger(__name__)
class TreeClient(Client):
def __init__(self,
ID=-1,
server_id=None,
state=0,
config=None,
data=None,
model=None,
device='cpu',
strategy=None,
*args,
**kwargs):
super(TreeClient,
self).__init__(ID, server_id, state, config, data, model, device,
strategy, *args, **kwargs)
self.data = data
self.own_label = ('y' in data['train'])
self.msg_buffer = {'train': {}, 'eval': {}}
self.client_num = self._cfg.federate.client_num
if self._cfg.vertical.eval_protection == 'he':
keys = abstract_paillier.generate_paillier_keypair(
n_length=self._cfg.vertical.key_size)
self.public_key, self.private_key = keys
self.feature_order = None
self.merged_feature_order = None
self.feature_partition = np.diff(self._cfg.vertical.dims, prepend=0)
self.total_num_of_feature = self._cfg.vertical.dims[-1]
self.num_of_feature = self.feature_partition[self.ID - 1]
self.feature_importance = [0] * self.num_of_feature
self._init_data_related_var()
self.register_handlers('model_para', self.callback_func_for_model_para)
self.register_handlers('data_sample',
self.callback_func_for_data_sample)
self.register_handlers('training_info',
self.callback_func_for_training_info)
self.register_handlers('finish', self.callback_func_for_finish)
def train(self, tree_num, node_num=None, training_info=None):
raise NotImplementedError
def eval(self, tree_num):
raise NotImplementedError
def _init_data_related_var(self):
self.trainer._init_for_train()
self.test_x = None
self.test_y = None
# all clients receive model para, and initial a tree list,
# each contains self.num_of_trees trees
# label-owner initials y_hat
# label-owner sends "sample data" to others
def callback_func_for_model_para(self, message: Message):
self.state = message.state
self.trainer.prepare_for_train()
if self.own_label:
batch_index, feature_order_info = self.trainer.fetch_train_data()
self.start_a_new_training_round(batch_index,
feature_order_info,
tree_num=0)
# other clients receive the data-sample information
def callback_func_for_data_sample(self, message: Message):
self.state = message.state
batch_index, sender = message.content, message.sender
_, feature_order_info = self.trainer.fetch_train_data(
index=batch_index)
self.feature_order = feature_order_info['feature_order']
if self._cfg.vertical.mode == 'feature_gathering':
training_info = feature_order_info
elif self._cfg.vertical.mode == 'label_scattering':
training_info = 'dummy_info'
else:
raise TypeError(
f'The expected types of vertical.mode include '
f'["label_scattering", "feature_gathering"], but got '
f'{self._cfg.vertical.mode}.')
self.comm_manager.send(
Message(msg_type='training_info',
sender=self.ID,
state=self.state,
receiver=[sender],
content=training_info))
def callback_func_for_training_info(self, message: Message):
feature_order_info, sender = message.content, message.sender
self.msg_buffer['train'][sender] = feature_order_info
self.check_and_move_on()
def callback_func_for_finish(self, message: Message):
logger.info(
f"================= client {self.ID} received finish message "
f"=================")
# self._monitor.finish_fl()
def start_a_new_training_round(self,
batch_index,
feature_order_info,
tree_num=0):
self.msg_buffer['train'].clear()
self.feature_order = feature_order_info['feature_order']
self.msg_buffer['train'][self.ID] = feature_order_info \
if self._cfg.vertical.mode == 'feature_gathering' else 'dummy_info'
self.state = tree_num
receiver = [
each for each in list(self.comm_manager.neighbors.keys())
if each != self.server_id
]
send_message = Message(msg_type='data_sample',
sender=self.ID,
state=self.state,
receiver=receiver,
content=batch_index)
self.comm_manager.send(send_message)
def check_and_move_on(self):
if len(self.msg_buffer['train']) == self.client_num:
received_training_infos = copy.deepcopy(self.msg_buffer['train'])
self.msg_buffer['train'].clear()
self.train(tree_num=self.state,
training_info=received_training_infos)