FS-TFP/federatedscope/core/sampler.py

167 lines
5.3 KiB
Python

from abc import ABC, abstractmethod
import numpy as np
class Sampler(ABC):
"""
The strategies of sampling clients for each training round
Arguments:
client_state: a dict to manager the state of clients (idle or busy)
"""
def __init__(self, client_num):
self.client_state = np.asarray([1] * (client_num + 1))
# Set the state of server (index=0) to 'working'
self.client_state[0] = 0
@abstractmethod
def sample(self, size):
raise NotImplementedError
def change_state(self, indices, state):
"""
To modify the state of clients (idle or working)
"""
if isinstance(indices, list) or isinstance(indices, np.ndarray):
all_idx = indices
else:
all_idx = [indices]
for idx in all_idx:
if state in ['idle', 'seen']:
self.client_state[idx] = 1
elif state in ['working', 'unseen']:
self.client_state[idx] = 0
else:
raise ValueError(
f"The state of client should be one of "
f"['idle', 'working', 'unseen], but got {state}")
class UniformSampler(Sampler):
"""
To uniformly sample the clients from all the idle clients
"""
def __init__(self, client_num):
super(UniformSampler, self).__init__(client_num)
def sample(self, size):
"""
To sample clients
"""
idle_clients = np.nonzero(self.client_state)[0]
sampled_clients = np.random.choice(idle_clients,
size=size,
replace=False).tolist()
self.change_state(sampled_clients, 'working')
return sampled_clients
class GroupSampler(Sampler):
"""
To grouply sample the clients based on their responsiveness (or other
client information of the clients)
"""
def __init__(self, client_num, client_info, bins=10):
super(GroupSampler, self).__init__(client_num)
self.bins = bins
self.update_client_info(client_info)
self.candidate_iterator = self.partition()
def update_client_info(self, client_info):
"""
To update the client information
"""
self.client_info = np.asarray(
[1.0] + [x for x in client_info
]) # client_info[0] is preversed for the server
assert len(self.client_info) == len(
self.client_state
), "The first dimension of client_info is mismatched with client_num"
def partition(self):
"""
To partition the clients into groups according to the client
information
Arguments:
:returns: a iteration of candidates
"""
sorted_index = np.argsort(self.client_info)
num_per_bins = np.int(len(sorted_index) / self.bins)
# grouped clients
self.grouped_clients = np.split(
sorted_index, np.cumsum([num_per_bins] * (self.bins - 1)))
return self.permutation()
def permutation(self):
candidates = list()
permutation = np.random.permutation(np.arange(self.bins))
for i in permutation:
np.random.shuffle(self.grouped_clients[i])
candidates.extend(self.grouped_clients[i])
return iter(candidates)
def sample(self, size, shuffle=False):
"""
To sample clients
"""
if shuffle:
self.candidate_iterator = self.permutation()
sampled_clients = list()
for i in range(size):
# To find an idle client
while True:
try:
item = next(self.candidate_iterator)
except StopIteration:
self.candidate_iterator = self.permutation()
item = next(self.candidate_iterator)
if self.client_state[item] == 1:
break
sampled_clients.append(item)
self.change_state(item, 'working')
return sampled_clients
class ResponsivenessRealtedSampler(Sampler):
"""
To sample the clients based on their responsiveness (or other information
of clients)
"""
def __init__(self, client_num, client_info):
super(ResponsivenessRealtedSampler, self).__init__(client_num)
self.update_client_info(client_info)
def update_client_info(self, client_info):
"""
To update the client information
"""
self.client_info = np.asarray(
[1.0] + [np.sqrt(x) for x in client_info
]) # client_info[0] is preversed for the server
assert len(self.client_info) == len(
self.client_state
), "The first dimension of client_info is mismatched with client_num"
def sample(self, size):
"""
To sample clients
"""
idle_clients = np.nonzero(self.client_state)[0]
client_info = self.client_info[idle_clients]
client_info = client_info / np.sum(client_info, keepdims=True)
sampled_clients = np.random.choice(idle_clients,
p=client_info,
size=size,
replace=False).tolist()
self.change_state(sampled_clients, 'working')
return sampled_clients