from abc import ABC, abstractmethod

import numpy as np


class Sampler(ABC):
    """
    The strategies of sampling clients for each training round

    Arguments:
        client_num: the number of clients; index 0 of every per-client
            array is reserved for the server, so clients are indexed
            from 1 to ``client_num``

    Attributes:
        client_state: an integer array of length ``client_num + 1`` that
            records whether each index may be sampled (1) or not (0)
    """
    def __init__(self, client_num):
        self.client_state = np.asarray([1] * (client_num + 1))
        # Set the state of server (index=0) to 'working' so it is never
        # picked by any sampler
        self.client_state[0] = 0

    @abstractmethod
    def sample(self, size):
        raise NotImplementedError

    def change_state(self, indices, state):
        """
        To modify the state of clients (idle or working)

        Arguments:
            indices: a single client index, or a list / tuple /
                np.ndarray of client indices
            state: 'idle' or 'seen' sets the entries to 1 (available);
                'working' or 'unseen' sets them to 0 (unavailable)

        Raises:
            ValueError: if ``state`` is not one of the values above
        """
        if state in ['idle', 'seen']:
            value = 1
        elif state in ['working', 'unseen']:
            value = 0
        else:
            raise ValueError(
                f"The state of client should be one of "
                f"['idle', 'working', 'seen', 'unseen'], but got {state}")

        if isinstance(indices, (list, tuple, np.ndarray)):
            all_idx = indices
        else:
            all_idx = [indices]
        for idx in all_idx:
            self.client_state[idx] = value


class UniformSampler(Sampler):
    """
    To uniformly sample the clients from all the idle clients
    """
    def __init__(self, client_num):
        super(UniformSampler, self).__init__(client_num)

    def sample(self, size):
        """
        To sample ``size`` distinct idle clients uniformly at random and
        mark them as 'working'

        :returns: a list of client indices (Python ints)
        :raises ValueError: (from ``np.random.choice``) if ``size`` exceeds
            the number of idle clients
        """
        idle_clients = np.nonzero(self.client_state)[0]
        sampled_clients = np.random.choice(idle_clients,
                                           size=size,
                                           replace=False).tolist()
        self.change_state(sampled_clients, 'working')
        return sampled_clients


class GroupSampler(Sampler):
    """
    To grouply sample the clients based on their responsiveness (or other
    client information of the clients)

    Arguments:
        client_num: the number of clients
        client_info: a sequence of ``client_num`` values, one per client,
            used to sort clients into groups
        bins: the number of groups
    """
    def __init__(self, client_num, client_info, bins=10):
        super(GroupSampler, self).__init__(client_num)
        self.bins = bins
        self.update_client_info(client_info)
        self.candidate_iterator = self.partition()

    def update_client_info(self, client_info):
        """
        To update the client information

        NOTE(review): this does not re-run ``partition``, so groups built
        earlier keep using the old information -- confirm that is intended.
        """
        # client_info[0] is reserved for the server
        self.client_info = np.asarray([1.0] + list(client_info))
        assert len(self.client_info) == len(
            self.client_state
        ), "The first dimension of client_info is mismatched with client_num"

    def partition(self):
        """
        To partition the clients into ``self.bins`` groups of consecutive
        entries in ascending order of the client information

        The server index 0 is part of the groups as well; it is skipped at
        sampling time because its state is always 0.

        :returns: an iterator over candidates (see ``permutation``)
        """
        sorted_index = np.argsort(self.client_info)
        # Floor division (``np.int`` was removed in NumPy 1.24); the last
        # group absorbs the remainder
        num_per_bins = len(sorted_index) // self.bins

        # grouped clients
        self.grouped_clients = np.split(
            sorted_index, np.cumsum([num_per_bins] * (self.bins - 1)))

        return self.permutation()

    def permutation(self):
        """
        To build an iterator that visits the groups in a random order,
        shuffling the members of each group in place

        :returns: an iterator over client indices
        """
        candidates = list()
        permutation = np.random.permutation(np.arange(self.bins))
        for i in permutation:
            np.random.shuffle(self.grouped_clients[i])
            candidates.extend(self.grouped_clients[i])

        return iter(candidates)

    def sample(self, size, shuffle=False):
        """
        To sample ``size`` idle clients by walking the candidate iterator
        and mark them as 'working'

        Arguments:
            size: the number of clients to sample
            shuffle: if True, start from a freshly permuted candidate order

        :returns: a list of client indices (Python ints)
        :raises ValueError: if ``size`` exceeds the number of idle clients
            (the search below would otherwise never terminate)
        """
        num_idle = int(np.count_nonzero(self.client_state))
        if size > num_idle:
            raise ValueError(
                f"Cannot sample {size} clients when only {num_idle} "
                f"client(s) are idle")

        if shuffle:
            self.candidate_iterator = self.permutation()

        sampled_clients = list()
        for _ in range(size):
            # To find an idle client; the iterator is refilled when
            # exhausted, and the check above guarantees one exists
            while True:
                try:
                    item = next(self.candidate_iterator)
                except StopIteration:
                    self.candidate_iterator = self.permutation()
                    item = next(self.candidate_iterator)
                if self.client_state[item] == 1:
                    break
            # Return Python ints, consistent with the other samplers
            sampled_clients.append(int(item))
            self.change_state(item, 'working')

        return sampled_clients


class ResponsivenessRealtedSampler(Sampler):
    """
    To sample the clients based on their responsiveness (or other
    information of clients)

    Each idle client is drawn with probability proportional to the square
    root of its client information.
    """
    def __init__(self, client_num, client_info):
        super(ResponsivenessRealtedSampler, self).__init__(client_num)
        self.update_client_info(client_info)

    def update_client_info(self, client_info):
        """
        To update the client information (stored as its square root)
        """
        # client_info[0] is reserved for the server
        self.client_info = np.asarray(
            [1.0] + [np.sqrt(x) for x in client_info])
        assert len(self.client_info) == len(
            self.client_state
        ), "The first dimension of client_info is mismatched with client_num"

    def sample(self, size):
        """
        To sample ``size`` distinct idle clients, weighted by the stored
        client information, and mark them as 'working'

        :returns: a list of client indices (Python ints)
        :raises ValueError: (from ``np.random.choice``) if ``size`` exceeds
            the number of idle clients
        """
        idle_clients = np.nonzero(self.client_state)[0]
        client_info = self.client_info[idle_clients]
        client_info = client_info / np.sum(client_info, keepdims=True)
        sampled_clients = np.random.choice(idle_clients,
                                           p=client_info,
                                           size=size,
                                           replace=False).tolist()
        self.change_state(sampled_clients, 'working')
        return sampled_clients


# Correctly spelled alias; the original name is kept for backward
# compatibility with existing callers.
ResponsivenessRelatedSampler = ResponsivenessRealtedSampler