FS-TFP/federatedscope/vertical_fl/linear_model/worker/vertical_client.py

114 lines
4.3 KiB
Python

import numpy as np
import logging
from federatedscope.core.workers import Client
from federatedscope.core.message import Message
from federatedscope.vertical_fl.dataloader.utils import batch_iter
class vFLClient(Client):
"""
The client class for vertical FL, which customizes the handled
functions. Please refer to the tutorial for more details about the
implementation algorithm
Implementation of Vertical FL refer to `Private federated learning on
vertically partitioned data via entity resolution and additively
homomorphic encryption` [Hardy, et al., 2017]
(https://arxiv.org/abs/1711.10677)
"""
def __init__(self,
ID=-1,
server_id=None,
state=-1,
config=None,
data=None,
model=None,
device='cpu',
strategy=None,
*args,
**kwargs):
super(vFLClient,
self).__init__(ID, server_id, state, config, data, model, device,
strategy, *args, **kwargs)
self.data = data
self.public_key = None
self.theta = None
self.batch_index = None
self._init_data_related_var()
self.register_handlers('public_keys',
self.callback_funcs_for_public_keys)
self.register_handlers('model_para',
self.callback_funcs_for_model_para)
self.register_handlers('encryped_gradient_u',
self.callback_funcs_for_encryped_gradient_u)
self.register_handlers('encryped_gradient_v',
self.callback_funcs_for_encryped_gradient_v)
def _init_data_related_var(self):
self.own_label = ('y' in self.data['train'])
self.dataloader = batch_iter(self.data['train'],
self._cfg.dataloader.batch_size,
shuffled=True)
def sample_data(self, index=None):
if index is None:
assert self.own_label
return next(self.dataloader)
else:
return self.data['train']['x'][index]
def callback_funcs_for_public_keys(self, message: Message):
self.public_key = message.content
def callback_funcs_for_model_para(self, message: Message):
self.theta = message.content
if self.own_label:
index, input_x, input_y = self.sample_data()
self.batch_index = index
u_A = 0.25 * np.matmul(input_x, self.theta) - 0.5 * input_y
en_u_A = [self.public_key.encrypt(x) for x in u_A]
self.comm_manager.send(
Message(msg_type='encryped_gradient_u',
sender=self.ID,
receiver=[
each for each in self.comm_manager.neighbors
if each != self.server_id
],
state=self.state,
content=(self.batch_index, en_u_A)))
def callback_funcs_for_encryped_gradient_u(self, message: Message):
index, en_u_A = message.content
self.batch_index = index
input_x = self.sample_data(index=self.batch_index)
u_B = 0.25 * np.matmul(input_x, self.theta)
en_u_B = [self.public_key.encrypt(x) for x in u_B]
en_u = np.expand_dims([sum(x) for x in zip(en_u_A, en_u_B)], -1)
en_v_B = en_u * input_x
self.comm_manager.send(
Message(msg_type='encryped_gradient_v',
sender=self.ID,
receiver=[
each for each in self.comm_manager.neighbors
if each != self.server_id
],
state=self.state,
content=(en_u, en_v_B)))
def callback_funcs_for_encryped_gradient_v(self, message: Message):
en_u, en_v_B = message.content
input_x = self.sample_data(index=self.batch_index)
en_v_A = en_u * input_x
en_v = np.concatenate([en_v_B, en_v_A], axis=-1)
self.comm_manager.send(
Message(msg_type='encryped_gradient',
sender=self.ID,
receiver=[self.server_id],
state=self.state,
content=(len(input_x), en_v)))