114 lines
4.3 KiB
Python
114 lines
4.3 KiB
Python
import numpy as np
|
|
import logging
|
|
|
|
from federatedscope.core.workers import Client
|
|
from federatedscope.core.message import Message
|
|
from federatedscope.vertical_fl.dataloader.utils import batch_iter
|
|
|
|
|
|
class vFLClient(Client):
    """
    The client class for vertical FL, which customizes the handled
    functions. Please refer to the tutorial for more details about the
    implementation algorithm.

    Implementation of Vertical FL refers to `Private federated learning on
    vertically partitioned data via entity resolution and additively
    homomorphic encryption` [Hardy, et al., 2017]
    (https://arxiv.org/abs/1711.10677)

    Message flow implemented by the handlers of this class:

    1. ``public_keys``: store the received key in ``self.public_key``.
    2. ``model_para``: store the received parameters in ``self.theta``; a
       client that owns labels samples a batch and sends
       ``encryped_gradient_u`` to the other clients.
    3. ``encryped_gradient_u``: compute the local contribution for the same
       batch indices and send ``encryped_gradient_v`` to the other clients.
    4. ``encryped_gradient_v``: compute the local part, concatenate both
       parts and send ``encryped_gradient`` to the server.

    NOTE: the misspelling ``encryped`` in the message types is part of the
    wire protocol and is kept on purpose.
    """

    def __init__(self,
                 ID=-1,
                 server_id=None,
                 state=-1,
                 config=None,
                 data=None,
                 model=None,
                 device='cpu',
                 strategy=None,
                 *args,
                 **kwargs):
        """
        Build the client, initialise the data related variables and register
        the message handlers used by the vertical FL protocol.

        All arguments are forwarded unchanged to ``Client.__init__``.
        ``data`` is additionally stored on the instance; this class indexes
        it as ``data['train']['x']`` and checks ``'y' in data['train']``.
        """
        super().__init__(ID, server_id, state, config, data, model, device,
                         strategy, *args, **kwargs)
        self.data = data
        # Filled in by the 'public_keys' handler; only ``.encrypt`` is
        # called on it here (presumably an additively homomorphic key as in
        # the referenced paper -- confirm against the server side).
        self.public_key = None
        # Filled in by the 'model_para' handler.
        self.theta = None
        # Indices of the batch currently being processed.
        self.batch_index = None
        self._init_data_related_var()

        self.register_handlers('public_keys',
                               self.callback_funcs_for_public_keys)
        self.register_handlers('model_para',
                               self.callback_funcs_for_model_para)
        self.register_handlers('encryped_gradient_u',
                               self.callback_funcs_for_encryped_gradient_u)
        self.register_handlers('encryped_gradient_v',
                               self.callback_funcs_for_encryped_gradient_v)

    def _init_data_related_var(self):
        """
        Derive ``self.own_label`` (whether the training split contains a
        ``'y'`` entry) and create the shuffled batch iterator over the
        training split, using ``self._cfg.dataloader.batch_size``.
        """
        self.own_label = ('y' in self.data['train'])
        self.dataloader = batch_iter(self.data['train'],
                                     self._cfg.dataloader.batch_size,
                                     shuffled=True)

    def _peer_client_ids(self):
        """
        Return every neighbor known to the communication manager except the
        server, i.e. the receivers of client-to-client messages.
        """
        return [
            each for each in self.comm_manager.neighbors
            if each != self.server_id
        ]

    def sample_data(self, index=None):
        """
        Fetch training data for the current step.

        Args:
            index: if ``None``, draw the next batch from the local batch
                iterator (only allowed for the client that owns the labels);
                otherwise select ``self.data['train']['x'][index]``.

        Returns:
            The next item produced by ``self.dataloader`` when ``index`` is
            ``None`` (unpacked by the caller as ``index, x, y``), else the
            selected rows of the local features.

        Raises:
            AssertionError: if ``index`` is ``None`` but this client does
                not own labels.
        """
        if index is None:
            # Raised explicitly instead of via ``assert`` so that the check
            # is not stripped when Python runs with ``-O``.
            if not self.own_label:
                raise AssertionError(
                    'Only the client that owns the labels can sample a '
                    'new batch')
            return next(self.dataloader)
        return self.data['train']['x'][index]

    def callback_funcs_for_public_keys(self, message: Message):
        """Store the content of a ``public_keys`` message as the public key."""
        self.public_key = message.content

    def callback_funcs_for_model_para(self, message: Message):
        """
        Store the received model parameters and, on the label-owning client,
        start a training step.

        The label owner samples a batch, computes
        ``0.25 * x @ theta - 0.5 * y``, encrypts it element-wise and sends
        the batch indices together with the encrypted values to the other
        clients as ``encryped_gradient_u``. Clients without labels only
        store ``theta``.
        """
        self.theta = message.content
        if not self.own_label:
            return

        index, input_x, input_y = self.sample_data()
        self.batch_index = index
        # NOTE(review): the 0.25 / 0.5 factors presumably come from the
        # Taylor approximation of the logistic loss in the referenced
        # paper -- confirm.
        u_A = 0.25 * np.matmul(input_x, self.theta) - 0.5 * input_y
        en_u_A = [self.public_key.encrypt(x) for x in u_A]

        self.comm_manager.send(
            Message(msg_type='encryped_gradient_u',
                    sender=self.ID,
                    receiver=self._peer_client_ids(),
                    state=self.state,
                    content=(self.batch_index, en_u_A)))

    def callback_funcs_for_encryped_gradient_u(self, message: Message):
        """
        Handle ``encryped_gradient_u``: combine the received encrypted values
        with the local contribution for the same batch.

        The message content is ``(index, en_u_A)``. The local features for
        ``index`` are used to compute ``0.25 * x @ theta``, which is
        encrypted element-wise and added to ``en_u_A``. The sum (with a
        trailing axis added) and its product with the local features are
        sent to the other clients as ``encryped_gradient_v``.
        """
        index, en_u_A = message.content
        self.batch_index = index
        input_x = self.sample_data(index=self.batch_index)
        u_B = 0.25 * np.matmul(input_x, self.theta)
        en_u_B = [self.public_key.encrypt(x) for x in u_B]
        # Element-wise sum of both parties' encrypted values; the trailing
        # axis lets the result multiply against the feature matrix below
        # (assumes input_x is 2-D (batch, features) -- TODO confirm).
        en_u = np.expand_dims([sum(x) for x in zip(en_u_A, en_u_B)], -1)
        en_v_B = en_u * input_x

        self.comm_manager.send(
            Message(msg_type='encryped_gradient_v',
                    sender=self.ID,
                    receiver=self._peer_client_ids(),
                    state=self.state,
                    content=(en_u, en_v_B)))

    def callback_funcs_for_encryped_gradient_v(self, message: Message):
        """
        Handle ``encryped_gradient_v``: compute the local part and forward
        the combined result to the server.

        The message content is ``(en_u, en_v_B)``. ``en_u`` is multiplied
        with the local features of the current batch, the result is
        concatenated after ``en_v_B`` along the last axis, and
        ``(batch size, concatenated values)`` is sent to the server as
        ``encryped_gradient``.
        """
        en_u, en_v_B = message.content
        input_x = self.sample_data(index=self.batch_index)
        en_v_A = en_u * input_x
        # Order matters: the sender's part (B) comes first, then the local
        # part (A).
        en_v = np.concatenate([en_v_B, en_v_A], axis=-1)

        self.comm_manager.send(
            Message(msg_type='encryped_gradient',
                    sender=self.ID,
                    receiver=[self.server_id],
                    state=self.state,
                    content=(len(input_x), en_v)))
|