FS-TFP/federatedscope/core/splitters/utils.py

88 lines
3.2 KiB
Python

import numpy as np
def _split_according_to_prior(label, client_num, prior):
assert client_num == len(prior)
classes = len(np.unique(label))
assert classes == len(np.unique(np.concatenate(prior, 0)))
# counting
frequency = np.zeros(shape=(client_num, classes))
for idx, client_prior in enumerate(prior):
for each in client_prior:
frequency[idx][each] += 1
sum_frequency = np.sum(frequency, axis=0)
idx_slice = [[] for _ in range(client_num)]
for k in range(classes):
idx_k = np.where(label == k)[0]
np.random.shuffle(idx_k)
nums_k = np.ceil(frequency[:, k] / sum_frequency[k] *
len(idx_k)).astype(int)
while len(idx_k) < np.sum(nums_k):
random_client = np.random.choice(range(client_num))
if nums_k[random_client] > 0:
nums_k[random_client] -= 1
assert len(idx_k) == np.sum(nums_k)
idx_slice = [
idx_j + idx.tolist() for idx_j, idx in zip(
idx_slice, np.split(idx_k,
np.cumsum(nums_k)[:-1]))
]
for i in range(len(idx_slice)):
np.random.shuffle(idx_slice[i])
return idx_slice
def dirichlet_distribution_noniid_slice(label,
client_num,
alpha,
min_size=1,
prior=None):
r"""Get sample index list for each client from the Dirichlet distribution.
https://github.com/FedML-AI/FedML/blob/master/fedml_core/non_iid
partition/noniid_partition.py
Arguments:
label (np.array): Label list to be split.
client_num (int): Split label into client_num parts.
alpha (float): alpha of LDA.
min_size (int): min number of sample in each client
Returns:
idx_slice (List): List of splited label index slice.
"""
if len(label.shape) != 1:
raise ValueError('Only support single-label tasks!')
if prior is not None:
return _split_according_to_prior(label, client_num, prior)
num = len(label)
classes = len(np.unique(label))
assert num > client_num * min_size, f'The number of sample should be ' \
f'greater than' \
f' {client_num * min_size}.'
size = 0
while size < min_size:
idx_slice = [[] for _ in range(client_num)]
for k in range(classes):
# for label k
idx_k = np.where(label == k)[0]
np.random.shuffle(idx_k)
prop = np.random.dirichlet(np.repeat(alpha, client_num))
# prop = np.array([
# p * (len(idx_j) < num / client_num)
# for p, idx_j in zip(prop, idx_slice)
# ])
# prop = prop / sum(prop)
prop = (np.cumsum(prop) * len(idx_k)).astype(int)[:-1]
idx_slice = [
idx_j + idx.tolist()
for idx_j, idx in zip(idx_slice, np.split(idx_k, prop))
]
size = min([len(idx_j) for idx_j in idx_slice])
for i in range(client_num):
np.random.shuffle(idx_slice[i])
return idx_slice