# FS-TFP/federatedscope/core/secret_sharing/secret_sharing.py
# (99 lines, 3.3 KiB, Python)

from abc import ABC, abstractmethod
import numpy as np
try:
import torch
except ImportError:
torch = None
class SecretSharing(ABC):
    """Abstract interface shared by all secret-sharing schemes.

    Concrete schemes must provide ``secret_split`` (turn a secret into
    frames) and ``secret_reconstruct`` (recover the secret from frames);
    the class cannot be instantiated until both are overridden.
    """

    def __init__(self):
        """Set up the base object; the base class keeps no state."""

    @abstractmethod
    def secret_split(self, secret):
        """Split ``secret`` into frames (to be provided by subclasses)."""

    @abstractmethod
    def secret_reconstruct(self, secret_seq):
        """Recover the secret from ``secret_seq`` (to be provided by
        subclasses)."""
class AdditiveSecretSharing(SecretSharing):
    """
    AdditiveSecretSharing class, which can split a number into frames and
    recover it by summing up.

    Each value is first encoded as a fixed-point number: it is scaled by
    ``epsilon`` (1e8), rounded, and reduced modulo ``mod_number``
    (``2 * 2**size + 1``). ``shared_party_num - 1`` frames are drawn
    uniformly from ``[0, mod_number)``, and the last frame is chosen so that
    all frames sum to the encoding modulo ``mod_number``.

    NOTE(review): ``_float2fixedpoint`` returns a float, so the encoding and
    the last frame are float64. Encodings close to ``mod_number`` (i.e. of
    negative numbers) lose low-order bits, and reconstruction is therefore
    approximate rather than exact.

    Args:
        shared_party_num (int): number of frames to produce; must be > 1.
        size (int): magnitude bound in bits; every value must satisfy
            ``abs(round(x * epsilon)) < 2**size``.
    """
    def __init__(self, shared_party_num, size=60):
        # Plain super() so SecretSharing.__init__ is part of the chain
        # (``super(SecretSharing, self)`` would skip it).
        super().__init__()
        assert shared_party_num > 1, "AdditiveSecretSharing require " \
                                     "shared_party_num > 1"
        self.shared_party_num = shared_party_num
        self.maximum = 2**size
        # Odd modulus covering [-maximum, maximum]; residues above
        # ``maximum`` decode as negative numbers in ``_fixedpoint2float``.
        self.mod_number = 2 * self.maximum + 1
        self.epsilon = 1e8
        # Element-wise helpers that accept scalars as well as arrays.
        self.mod_funs = np.vectorize(lambda x: x % self.mod_number)
        self.float2fixedpoint = np.vectorize(self._float2fixedpoint)
        self.fixedpoint2float = np.vectorize(self._fixedpoint2float)

    def secret_split(self, secret):
        """
        To split the secret into frames according to the shared_party_num.

        Args:
            secret: a scalar, a list / tuple / ``np.ndarray``, a
                ``torch.Tensor``, or a (possibly nested) dict whose leaves
                are any of these.

        Returns:
            For a dict: a list of ``shared_party_num`` dicts with the same
            keys, the i-th dict holding frame i of every value. Otherwise:
            an ``np.ndarray`` of shape ``(shared_party_num, *secret.shape)``
            (``(shared_party_num,)`` for a scalar) whose entry i is frame i.
        """
        if isinstance(secret, dict):
            secret_list = [dict() for _ in range(self.shared_party_num)]
            for key in secret:
                for idx, each in enumerate(self.secret_split(secret[key])):
                    secret_list[idx][key] = each
            return secret_list

        if isinstance(secret, (list, tuple, np.ndarray)):
            secret = np.asarray(secret)
            shape = [self.shared_party_num - 1] + list(secret.shape)
        elif torch is not None and isinstance(secret, torch.Tensor):
            # ``torch`` is None when it is not installed; detach/cpu lets
            # tensors that require grad or live on an accelerator convert.
            secret = secret.detach().cpu().numpy()
            shape = [self.shared_party_num - 1] + list(secret.shape)
        else:
            shape = [self.shared_party_num - 1]
        secret = self.float2fixedpoint(secret)
        # Explicit int64: the platform default integer may be 32-bit, which
        # cannot represent ``mod_number``.
        secret_seq = np.random.randint(low=0,
                                       high=self.mod_number,
                                       size=shape,
                                       dtype=np.int64)
        # Sum in Python ints (dtype=object): frames can be as large as
        # ``maximum``, so an int64 sum over several frames would wrap modulo
        # 2**64 instead of being reduced modulo ``mod_number``.
        partial_sum = self.mod_funs(np.sum(secret_seq, axis=0, dtype=object))
        last_seq = self.mod_funs(secret - partial_sum)
        secret_seq = np.append(secret_seq,
                               np.expand_dims(last_seq, axis=0),
                               axis=0)
        return secret_seq

    def secret_reconstruct(self, secret_seq):
        """
        To recover the secret from its frames.

        Args:
            secret_seq: exactly ``shared_party_num`` frames, either a
                sequence of (possibly nested) dicts as produced by
                ``secret_split`` for dict secrets, or a sequence /
                ``np.ndarray`` of numeric frames.

        Returns:
            For dict frames, a copy of the first frame's dict whose values
            are the decoded secrets (nested dicts are handled recursively);
            otherwise the decoded value(s) as returned by
            ``fixedpoint2float``. The input frames are not modified.
        """
        assert len(secret_seq) == self.shared_party_num
        first = secret_seq[0]
        if isinstance(first, dict):
            merge_model = first.copy()
            for key in merge_model:
                merge_model[key] = self.secret_reconstruct(
                    [frame[key] for frame in secret_seq])
            return merge_model
        # Out-of-place ``+`` so the caller's first frame is never mutated;
        # frames are added left to right before decoding.
        total = first
        for frame in secret_seq[1:]:
            total = total + frame
        return self.fixedpoint2float(total)

    def _float2fixedpoint(self, x):
        """Encode one number as ``round(x * epsilon) % mod_number``.

        Raises:
            AssertionError: if ``abs(round(x * epsilon)) >= maximum``.
        """
        x = round(x * self.epsilon, 0)
        assert abs(x) < self.maximum
        return x % self.mod_number

    def _fixedpoint2float(self, x):
        """Decode one value produced by ``_float2fixedpoint``.

        After reduction modulo ``mod_number``, residues above ``maximum``
        stand for negative numbers.
        """
        x = x % self.mod_number
        if x > self.maximum:
            return -1 * (self.mod_number - x) / self.epsilon
        else:
            return x / self.epsilon