import numpy as np
import logging

from federatedscope.core.workers import Server
from federatedscope.core.message import Message
from federatedscope.vertical_fl.Paillier import abstract_paillier
from federatedscope.core.auxiliaries.model_builder import get_model

logger = logging.getLogger(__name__)


class vFLServer(Server):
    """
    The server class for vertical FL, which customizes the handled
    functions. Please refer to the tutorial for more details about the
    implemented algorithm.

    The implementation of vertical FL follows `Private federated learning
    on vertically partitioned data via entity resolution and additively
    homomorphic encryption` [Hardy, et al., 2017]
    (https://arxiv.org/abs/1711.10677)
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(vFLServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)
        cfg_key_size = config.vertical.key_size
        self.public_key, self.private_key = \
            abstract_paillier.generate_paillier_keypair(n_length=cfg_key_size)
        self.vertical_dims = config.vertical.dims
        self._init_data_related_var()

        self.lr = config.train.optimizer.lr

        self.register_handlers('encryped_gradient',
                               self.callback_funcs_for_encryped_gradient)
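
    # The protocol relies on the additive homomorphism of Paillier: sums of
    # ciphertexts decrypt to sums of plaintexts, so the clients can combine
    # encrypted intermediate values without seeing them in the clear, and only
    # the server, which holds the private key, decrypts the result. A minimal
    # sketch, assuming the keys returned by `generate_paillier_keypair` follow
    # the python-paillier interface (`encrypt`/`decrypt`, `+` on ciphertexts):
    #
    #   pk, sk = abstract_paillier.generate_paillier_keypair(n_length=3072)
    #   c = pk.encrypt(0.25) + pk.encrypt(0.75)   # addition in cipher space
    #   sk.decrypt(c)                             # ~= 1.0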

    def _init_data_related_var(self):
        self.dims = [0] + self.vertical_dims
        self.model = get_model(self._cfg.model, self.data)
        self.theta = self.model.state_dict()['fc.weight'].numpy().reshape(-1)

    def trigger_for_start(self):
        if self.check_client_join_in():
            self.broadcast_public_keys()
            self.broadcast_client_address()
            self.trigger_for_feat_engr(self.broadcast_model_para)

    def broadcast_public_keys(self):
        self.comm_manager.send(
            Message(msg_type='public_keys',
                    sender=self.ID,
                    receiver=list(self.comm_manager.get_neighbors().keys()),
                    state=self.state,
                    content=self.public_key))

    def broadcast_model_para(self):
        client_ids = self.comm_manager.neighbors.keys()
        cur_idx = 0
        for client_id in client_ids:
            theta_slices = self.theta[cur_idx:cur_idx +
                                      self.dims[int(client_id)]]
            self.comm_manager.send(
                Message(msg_type='model_para',
                        sender=self.ID,
                        receiver=client_id,
                        state=self.state,
                        content=theta_slices))
            cur_idx += self.dims[int(client_id)]
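
    # For illustration (hypothetical numbers): with self.dims == [0, 3, 5],
    # client 1 would receive theta[0:3] and client 2 theta[3:8], i.e. each
    # client only gets the coefficients for its own feature columns.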

    def callback_funcs_for_encryped_gradient(self, message: Message):
        sample_num, en_v = message.content

        v = np.reshape(
            [self.private_key.decrypt(x) for x in np.reshape(en_v, -1)],
            [sample_num, -1])
        avg_gradients = np.mean(v, axis=0)
        self.theta = self.theta - self.lr * avg_gradients

        self.state += 1
        if self.state % self._cfg.eval.freq == 0 and self.state != \
                self.total_round_num:
            metrics = self.evaluate()
            self._monitor.update_best_result(self.best_results,
                                             metrics,
                                             results_type='server_global_eval')
            formatted_logs = self._monitor.format_eval_res(
                metrics,
                rnd=self.state,
                role='Server #',
                forms=self._cfg.eval.report)
            logger.info(formatted_logs)

        if self.state < self.total_round_num:
            # Move to next round of training
            logger.info(f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
            self.broadcast_model_para()
        else:
            metrics = self.evaluate()
            self._monitor.update_best_result(self.best_results,
                                             metrics,
                                             results_type='server_global_eval')
            formatted_logs = self._monitor.format_eval_res(
                metrics,
                rnd=self.state,
                role='Server #',
                forms=self._cfg.eval.report)
            logger.info(formatted_logs)

    def evaluate(self):
        test_x = self.data['test']['x']
        test_y = self.data['test']['y']
        loss = np.mean(
            np.log(1 + np.exp(-test_y * np.matmul(test_x, self.theta))))
        acc = np.mean((test_y * np.matmul(test_x, self.theta)) > 0)

        return {'test_loss': loss, 'test_acc': acc, 'test_total': len(test_y)}
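
    # Sanity check of the metrics above with hypothetical numbers (the margin
    # test_y * x.theta is consistent with labels in {-1, +1}):
    #   test_x = np.array([[1., 0.], [0., 1.]])
    #   test_y = np.array([1, -1])
    #   theta = np.array([2., -2.])
    #   margins = test_y * np.matmul(test_x, theta)   # [2., 2.]
    #   np.mean(np.log(1 + np.exp(-margins)))         # ~0.127 (test_loss)
    #   np.mean(margins > 0)                          # 1.0    (test_acc)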