88 lines
3.2 KiB
Python
88 lines
3.2 KiB
Python
import numpy as np
|
|
|
|
|
|
def _split_according_to_prior(label, client_num, prior):
    """Split sample indices among clients according to a per-client prior.

    ``prior[i]`` lists the labels held by client ``i``; each occurrence of a
    label counts as one share of that class. The samples of every class are
    divided among clients in proportion to their shares: each client's count
    is rounded up, then one sample at a time is taken back from a randomly
    chosen client (that still has samples of the class) until the counts sum
    to the class size.

    Labels need not be the contiguous integers ``0..C-1``; they are matched
    by value against ``np.unique(label)``.

    Arguments:
        label (np.array): 1-D array with one class label per sample.
        client_num (int): Number of clients; must equal ``len(prior)``.
        prior (Sequence): ``client_num`` array-likes of labels, one per
            client, that ``np.concatenate`` can join.

    Returns:
        idx_slice (List[List[int]]): Shuffled sample indices for each client.

    Raises:
        ValueError: If ``len(prior) != client_num`` or the set of labels in
            ``prior`` differs from the set of labels in ``label``.
    """
    label = np.asarray(label)
    if client_num != len(prior):
        raise ValueError(f'Expected {client_num} client priors, '
                         f'got {len(prior)}.')

    classes = np.unique(label)
    prior_classes = np.unique(np.concatenate(prior, 0))
    # Every class must be claimed by some client (otherwise its samples would
    # be divided by a zero share count) and every prior label must exist.
    if not np.array_equal(classes, prior_classes):
        raise ValueError('The labels in `prior` must match the labels in '
                         '`label`.')

    # frequency[i, c]: number of times client i lists class classes[c].
    frequency = np.zeros(shape=(client_num, len(classes)))
    for idx, client_prior in enumerate(prior):
        # Map label values to column indices; add.at counts repeated labels.
        np.add.at(frequency[idx], np.searchsorted(classes, client_prior), 1)
    sum_frequency = np.sum(frequency, axis=0)

    idx_slice = [[] for _ in range(client_num)]
    for col, cls in enumerate(classes):
        idx_k = np.where(label == cls)[0]
        np.random.shuffle(idx_k)
        nums_k = np.ceil(frequency[:, col] / sum_frequency[col] *
                         len(idx_k)).astype(int)
        # Rounding up may over-allocate; trim one sample at a time from
        # random clients that still have a positive count.
        while len(idx_k) < np.sum(nums_k):
            random_client = np.random.choice(range(client_num))
            if nums_k[random_client] > 0:
                nums_k[random_client] -= 1
        # Internal invariant: the ceil-ed counts never undershoot.
        assert len(idx_k) == np.sum(nums_k)
        idx_slice = [
            idx_j + idx.tolist() for idx_j, idx in zip(
                idx_slice, np.split(idx_k, np.cumsum(nums_k)[:-1]))
        ]

    for client_idx in idx_slice:
        np.random.shuffle(client_idx)
    return idx_slice
|
|
|
|
|
|
def dirichlet_distribution_noniid_slice(label,
                                        client_num,
                                        alpha,
                                        min_size=1,
                                        prior=None):
    r"""Get sample index list for each client from the Dirichlet distribution.

    For every class, a proportion vector is drawn from a symmetric Dirichlet
    distribution with concentration ``alpha`` and that class's (shuffled)
    samples are cut into ``client_num`` consecutive pieces accordingly. The
    whole draw is repeated until every client holds at least ``min_size``
    samples.

    Adapted from FedML (``fedml_core/non_iid_partition/noniid_partition.py``
    in https://github.com/FedML-AI/FedML).

    Arguments:
        label (np.array): 1-D array-like with one class label per sample.
            Labels are matched by value, so they need not be ``0..C-1``.
        client_num (int): Split label into client_num parts.
        alpha (float): Concentration parameter of the symmetric Dirichlet
            (LDA) distribution drawn once per class.
        min_size (int): Minimum number of samples each client must receive.
        prior (Sequence, optional): Per-client label lists; when given, the
            split is delegated to ``_split_according_to_prior`` and
            ``alpha``/``min_size`` are not used.

    Returns:
        idx_slice (List[List[int]]): Shuffled sample indices for each client.

    Raises:
        ValueError: If ``label`` is not 1-D, or (without ``prior``) if it
            has no more than ``client_num * min_size`` samples.
    """
    label = np.asarray(label)
    if len(label.shape) != 1:
        raise ValueError('Only support single-label tasks!')

    if prior is not None:
        return _split_according_to_prior(label, client_num, prior)

    num = len(label)
    # Iterate over the actual label values: ``range(len(np.unique(label)))``
    # would silently drop every sample whose label is not in 0..C-1.
    classes = np.unique(label)
    if num <= client_num * min_size:
        raise ValueError(f'The number of sample should be '
                         f'greater than {client_num * min_size}.')

    size = 0
    while size < min_size:
        idx_slice = [[] for _ in range(client_num)]
        for k in classes:
            idx_k = np.where(label == k)[0]
            np.random.shuffle(idx_k)
            prop = np.random.dirichlet(np.repeat(alpha, client_num))
            # Turn the proportions into cut points over this class's samples.
            cuts = (np.cumsum(prop) * len(idx_k)).astype(int)[:-1]
            idx_slice = [
                idx_j + idx.tolist()
                for idx_j, idx in zip(idx_slice, np.split(idx_k, cuts))
            ]
        size = min(len(idx_j) for idx_j in idx_slice)

    for client_idx in idx_slice:
        np.random.shuffle(client_idx)
    return idx_slice
|