167 lines
5.3 KiB
Python
167 lines
5.3 KiB
Python
from abc import ABC, abstractmethod
|
|
|
|
import numpy as np
|
|
|
|
|
|
class Sampler(ABC):
    """
    The strategies of sampling clients for each training round

    Arguments:
        client_state: a numpy array of length ``client_num + 1`` that
            tracks the state of each client: 1 means idle (can be sampled)
            and 0 means working. Index 0 is reserved for the server and is
            marked as working at construction time.
    """
    def __init__(self, client_num):
        self.client_state = np.asarray([1] * (client_num + 1))
        # Set the state of server (index=0) to 'working'
        self.client_state[0] = 0

    @abstractmethod
    def sample(self, size):
        """
        To sample ``size`` clients; implemented by subclasses
        """
        raise NotImplementedError

    def change_state(self, indices, state):
        """
        To modify the state of clients (idle or working)

        Arguments:
            indices: a single client index, or a list / tuple / numpy array
                of client indices
            state: 'idle' or 'seen' sets the clients to 1; 'working' or
                'unseen' sets them to 0

        Raises:
            ValueError: if ``state`` is not one of the accepted values
                (checked even when ``indices`` is empty)
        """
        # Validate the state once, before touching any client, so an invalid
        # state is reported regardless of how many indices are given.
        if state in ['idle', 'seen']:
            value = 1
        elif state in ['working', 'unseen']:
            value = 0
        else:
            raise ValueError(
                f"The state of client should be one of "
                f"['idle', 'working', 'seen', 'unseen'], but got {state}")

        if isinstance(indices, (list, tuple, np.ndarray)):
            all_idx = indices
        else:
            all_idx = [indices]

        for idx in all_idx:
            self.client_state[idx] = value
|
|
|
|
|
|
class UniformSampler(Sampler):
    """
    Sample clients uniformly at random, without replacement, from the
    clients that are currently idle.
    """
    def __init__(self, client_num):
        super().__init__(client_num)

    def sample(self, size):
        """
        Draw ``size`` distinct idle clients uniformly at random and mark
        them as 'working'.

        :returns: a list with the indices of the sampled clients
        """
        (idle_clients, ) = np.nonzero(self.client_state)
        picked = np.random.choice(idle_clients, size=size, replace=False)
        sampled_clients = picked.tolist()
        self.change_state(sampled_clients, 'working')
        return sampled_clients
|
|
|
|
|
|
class GroupSampler(Sampler):
    """
    To sample the clients group by group based on their responsiveness (or
    other client information of the clients)

    The clients are sorted by ``client_info`` and split into ``bins``
    groups. Candidates are visited group by group in a random group order,
    with the clients inside each group shuffled; only idle clients are
    returned by ``sample``.
    """
    def __init__(self, client_num, client_info, bins=10):
        super(GroupSampler, self).__init__(client_num)
        self.bins = bins
        self.update_client_info(client_info)
        self.candidate_iterator = self.partition()

    def update_client_info(self, client_info):
        """
        To update the client information

        Arguments:
            client_info: a sequence with one value per client (length
                ``client_num``); a placeholder value 1.0 is prepended at
                index 0, which is reserved for the server

        NOTE: this does not rebuild ``grouped_clients``; the groups built
        by ``partition()`` are left unchanged.
        """
        # client_info[0] is reserved for the server
        self.client_info = np.asarray([1.0] + list(client_info))
        assert len(self.client_info) == len(
            self.client_state
        ), "The first dimension of client_info is mismatched with client_num"

    def partition(self):
        """
        To partition the clients into groups according to the client
        information

        All indices (including the server at index 0) are sorted by
        ``client_info`` in ascending order and split into ``self.bins``
        groups: the first ``bins - 1`` groups hold
        ``len(client_info) // bins`` indices each and the last group takes
        the remainder.

        :returns: an iterator of candidates (see ``permutation``)
        """
        sorted_index = np.argsort(self.client_info)
        # ``np.int`` was removed in NumPy 1.24; floor division yields the
        # same value for these non-negative integers.
        num_per_bins = len(sorted_index) // self.bins

        # grouped clients
        self.grouped_clients = np.split(
            sorted_index, np.cumsum([num_per_bins] * (self.bins - 1)))

        return self.permutation()

    def permutation(self):
        """
        To build a fresh iterator of candidates: the groups are visited in a
        random order and the indices inside each group are shuffled in place.

        :returns: an iterator over client indices
        """
        candidates = list()
        permutation = np.random.permutation(np.arange(self.bins))
        for i in permutation:
            np.random.shuffle(self.grouped_clients[i])
            candidates.extend(self.grouped_clients[i])

        return iter(candidates)

    def sample(self, size, shuffle=False):
        """
        To sample clients

        Arguments:
            size: the number of clients to sample
            shuffle: if True, restart from a fresh permutation of the
                candidates before sampling

        :returns: a list of sampled client indices; each of them is marked
            as 'working'

        Raises:
            ValueError: if fewer than ``size`` clients are idle (the search
                for an idle client below would otherwise never terminate)
        """
        num_idle = int(np.sum(self.client_state == 1))
        if size > num_idle:
            raise ValueError(
                f"Cannot sample {size} clients since only {num_idle} "
                f"clients are idle")

        if shuffle:
            self.candidate_iterator = self.permutation()

        sampled_clients = list()
        for _ in range(size):
            # To find an idle client; one exists thanks to the check above,
            # and every client index appears in each permutation.
            while True:
                try:
                    item = next(self.candidate_iterator)
                except StopIteration:
                    self.candidate_iterator = self.permutation()
                    item = next(self.candidate_iterator)

                if self.client_state[item] == 1:
                    break

            sampled_clients.append(item)
            self.change_state(item, 'working')

        return sampled_clients
|
|
|
|
|
|
class ResponsivenessRealtedSampler(Sampler):
    """
    Sample clients with probability proportional to the square root of
    their responsiveness (or other information of clients)
    """
    def __init__(self, client_num, client_info):
        super().__init__(client_num)
        self.update_client_info(client_info)

    def update_client_info(self, client_info):
        """
        Store the square root of each client's information as its sampling
        weight.

        Index 0 is reserved for the server and receives the placeholder
        weight 1.0.
        """
        weights = [1.0]
        weights.extend(np.sqrt(x) for x in client_info)
        self.client_info = np.asarray(weights)
        assert len(self.client_info) == len(
            self.client_state
        ), "The first dimension of client_info is mismatched with client_num"

    def sample(self, size):
        """
        Draw ``size`` distinct idle clients, each with probability
        proportional to its weight among the idle clients, and mark them
        as 'working'.

        :returns: a list with the indices of the sampled clients
        """
        (idle_clients, ) = np.nonzero(self.client_state)
        weights = self.client_info[idle_clients]
        probs = weights / np.sum(weights, keepdims=True)
        chosen = np.random.choice(idle_clients,
                                  size=size,
                                  replace=False,
                                  p=probs)
        sampled_clients = chosen.tolist()
        self.change_state(sampled_clients, 'working')
        return sampled_clients
|