import os
import logging
from itertools import product

import yaml

import numpy as np
from numpy.linalg import norm
from scipy.special import logsumexp
import torch

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.autotune.fedex.utils import HyperNet

logger = logging.getLogger(__name__)


def discounted_mean(trace, factor=1.0):
    """Computes a discounted mean of ``trace``: with ``factor < 1``, earlier
    entries are discounted and the most recent entry has weight 1; with the
    default ``factor=1.0`` all entries are weighted equally.
    """
    weight = factor**np.flip(np.arange(len(trace)), axis=0)

    return np.inner(trace, weight) / weight.sum()
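

# Illustrative note (values chosen only for this example): with
# trace = [1.0, 2.0, 3.0] and factor = 0.5, the weights are
# 0.5 ** [2, 1, 0] = [0.25, 0.5, 1.0], so the discounted mean is
# (0.25 * 1.0 + 0.5 * 2.0 + 1.0 * 3.0) / 1.75 = 4.25 / 1.75 ~ 2.43.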


class FedExServer(Server):
    """Server that runs the FedEx policy for federated hyperparameter
    optimization. Some code snippets are borrowed from the open-sourced
    FedEx (https://github.com/mkhodak/FedEx).
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):

        super(FedExServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        # initialize action space and the policy
        with open(config.hpo.fedex.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)

        if next(iter(ss.keys())).startswith('arm'):
            # This is a flattened action space: each key 'armX' indexes one
            # candidate configuration. Sort by the arm index so the original
            # order is unchanged.
            ss = sorted([(int(k[3:]), v) for k, v in ss.items()],
                        key=lambda x: x[0])
            self._grid = []
            self._cfsp = [[tp[1] for tp in ss]]
        else:
            # This is a structured (non-flattened) search space; sort the
            # hyperparameter names to keep their order deterministic.
            self._grid = sorted(ss.keys())
            self._cfsp = [ss[pn] for pn in self._grid]

        sizes = [len(cand_set) for cand_set in self._cfsp]
        eta0 = 'auto' if config.hpo.fedex.eta0 <= .0 else float(
            config.hpo.fedex.eta0)
        self._eta0 = [
            np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0
            for size in sizes
        ]
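        # Note: when eta0 is left at 'auto', the per-hyperparameter step size
        # sqrt(2 * log(K)) (K candidate values) matches the usual step-size
        # scaling of exponentiated-gradient / Exp3-style updates, which is
        # the kind of update performed in update_policy() below.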
        self._sched = config.hpo.fedex.sched
        self._cutoff = config.hpo.fedex.cutoff
        self._baseline = config.hpo.fedex.gamma
        self._diff = config.hpo.fedex.diff
        if self._cfg.hpo.fedex.psn:
            # personalized policy
            # TODO: client-wise RFF
            self._client_encodings = torch.randn(
                (client_num, 8), device=device) / np.sqrt(8)
            self._policy_net = HyperNet(
                self._client_encodings.shape[-1],
                sizes,
                client_num,
                device,
            ).to(device)
            self._policy_net.eval()
            theta4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
            self._pn_optimizer = torch.optim.Adam(
                self._policy_net.parameters(),
                lr=self._cfg.hpo.fedex.pi_lr,
                weight_decay=1e-5)
        else:
            self._z = [np.full(size, -np.log(size)) for size in sizes]
            self._theta = [np.exp(z) for z in self._z]
            theta4stat = self._theta
        self._store = [0.0 for _ in sizes]
        self._stop_exploration = False
        self._trace = {
            'global': [],
            'refine': [],
            'entropy': [self.entropy(theta4stat)],
            'mle': [self.mle(theta4stat)]
        }
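        # The non-personalized policy starts from log-uniform logits
        # (z = -log(K) per candidate), i.e., a uniform distribution theta
        # over each hyperparameter's candidates. Per round, the four traces
        # record the weighted validation loss before ('global') and after
        # ('refine') the local updates, the policy entropy, and the
        # probability of the most likely configuration ('mle').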

        if self._cfg.federate.restore_from != '':
            if not os.path.exists(self._cfg.federate.restore_from):
                logger.warning(f'Invalid `restore_from`:'
                               f' {self._cfg.federate.restore_from}.')
            else:
                pi_ckpt_path = self._cfg.federate.restore_from[
                    :self._cfg.federate.restore_from.rfind('.')] \
                    + "_fedex.yaml"
                with open(pi_ckpt_path, 'r') as ips:
                    ckpt = yaml.load(ips, Loader=yaml.FullLoader)
                if self._cfg.hpo.fedex.psn:
                    psn_pi_ckpt_path = self._cfg.federate.restore_from[
                        :self._cfg.federate.restore_from.rfind('.')] \
                        + "_pfedex.pt"
                    psn_pi = torch.load(psn_pi_ckpt_path, map_location=device)
                    self._client_encodings = psn_pi['client_encodings']
                    self._policy_net.load_state_dict(psn_pi['policy_net'])
                else:
                    self._z = [np.asarray(z) for z in ckpt['z']]
                    self._theta = [np.exp(z) for z in self._z]
                    self._store = ckpt['store']
                self._stop_exploration = ckpt['stop']
                self._trace = dict()
                self._trace['global'] = ckpt['global']
                self._trace['refine'] = ckpt['refine']
                self._trace['entropy'] = ckpt['entropy']
                self._trace['mle'] = ckpt['mle']

    def entropy(self, thetas):
        """Computes the entropy of the joint (product) distribution over
        configurations induced by the per-hyperparameter probability
        vectors; for a personalized policy, the entropy is averaged over
        clients.
        """
        if self._cfg.hpo.fedex.psn:
            entropy = 0.0
            for i in range(thetas[0].shape[0]):
                for probs in product(*(theta[i][theta[i] > 0.0]
                                       for theta in thetas)):
                    prob = np.prod(probs)
                    entropy -= prob * np.log(prob)
            return entropy / float(thetas[0].shape[0])
        else:
            entropy = 0.0
            for probs in product(*(theta[theta > 0.0] for theta in thetas)):
                prob = np.prod(probs)
                entropy -= prob * np.log(prob)
            return entropy

    def mle(self, thetas):
        """Returns the probability of the most likely configuration under
        the current policy (averaged over clients for a personalized
        policy).
        """
        if self._cfg.hpo.fedex.psn:
            return np.prod([theta.max(-1) for theta in thetas], 0).mean()
        else:
            return np.prod([theta.max() for theta in thetas])

    def trace(self, key):
        '''returns the trace of one of the tracked quantities

        Args:
            key (str): 'global', 'refine', 'entropy', or 'mle'
        Returns:
            numpy vector with length equal to the number of rounds so far.
        '''

        return np.array(self._trace[key])

    def sample(self, thetas):
        """samples candidate configurations using the current probability
        vectors

        Arguments:
            thetas (list): probabilities for the hyperparameters.
        """

        # determine the index of the sampled candidate for each
        # hyperparameter; once exploration has stopped, always pick the
        # most likely one
        if self._stop_exploration:
            cfg_idx = [int(theta.argmax()) for theta in thetas]
        else:
            cfg_idx = [
                np.random.choice(len(theta), p=theta) for theta in thetas
            ]

        # get the sampled value(s)
        if self._grid:
            sampled_cfg = {
                pn: cands[i]
                for pn, cands, i in zip(self._grid, self._cfsp, cfg_idx)
            }
        else:
            sampled_cfg = self._cfsp[0][cfg_idx[0]]

        return cfg_idx, sampled_cfg
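
    # A hypothetical example (names and values are illustrative only): for a
    # structured search space with self._grid == ['dropout', 'lr'] and
    # candidate lists self._cfsp == [[0.0, 0.5], [0.01, 0.1]], sample() may
    # return ([1, 0], {'dropout': 0.5, 'lr': 0.01}), i.e., the
    # per-hyperparameter arm indices together with the concrete
    # configuration sent to a client.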

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        Broadcasts the model parameters, together with a sampled
        hyperparameter configuration per client, to all or sampled clients.
        """
        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            receiver = self.sampler.sample(size=sample_client_num)
        else:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')

        if self._noise_injector is not None and msg_type == 'model_para':
            # Inject noise only when broadcasting parameters
            for model_idx_i in range(len(self.models)):
                num_sample_clients = [
                    v["num_sample"] for v in self.join_in_info.values()
                ]
                self._noise_injector(self._cfg, num_sample_clients,
                                     self.models[model_idx_i])

        if self.model_num > 1:
            model_para = [model.state_dict() for model in self.models]
        else:
            model_para = self.model.state_dict()

        # sample the hyper-parameter config specific to the clients
        if self._cfg.hpo.fedex.psn:
            self._policy_net.train()
            self._pn_optimizer.zero_grad()
            self._theta = self._policy_net(self._client_encodings)
        for rcv_idx in receiver:
            if self._cfg.hpo.fedex.psn:
                cfg_idx, sampled_cfg = self.sample([
                    theta[rcv_idx - 1].detach().cpu().numpy()
                    for theta in self._theta
                ])
            else:
                cfg_idx, sampled_cfg = self.sample(self._theta)
            content = {
                'model_param': model_para,
                'arms': cfg_idx,
                'hyperparam': sampled_cfg
            }
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))
        if self._cfg.federate.online_aggr:
            for idx in range(self.model_num):
                self.aggregators[idx].reset()

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')

    def callback_funcs_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, \
            message.content
        self.sampler.change_state(sender, 'idle')
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()

        self.msg_buffer['train'][round][sender] = content

        if self._cfg.federate.online_aggr:
            self.aggregator.inc(tuple(content[0:2]))

        return self.check_and_move_on()
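
    # Note: as used in check_and_move_on() below, the content reported by
    # each client is a tuple whose first two entries are
    # (num_train_samples, model_parameters) and whose third entry is the
    # feedback dict consumed by update_policy().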

    def update_policy(self, feedbacks):
        """Update the policy. This implementation is borrowed from the
        open-sourced FedEx (
        https://github.com/mkhodak/FedEx/blob/ \
        150fac03857a3239429734d59d319da71191872e/hyper.py#L151)

        Arguments:
            feedbacks (list): each element is a dict containing the chosen
                "arms", the 'client_id', and the validation feedback
                ('val_avg_loss_before', 'val_avg_loss_after', 'val_total').
        """

        index = [elem['arms'] for elem in feedbacks]
        cids = [elem['client_id'] for elem in feedbacks]
        before = np.asarray(
            [elem['val_avg_loss_before'] for elem in feedbacks])
        after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks])
        weight = np.asarray([elem['val_total'] for elem in feedbacks],
                            dtype=np.float64)
        weight /= np.sum(weight)

        if self._trace['refine']:
            trace = self.trace('refine')
            if self._diff:
                trace -= self.trace('global')
            baseline = discounted_mean(trace, self._baseline)
        else:
            baseline = 0.0
        self._trace['global'].append(np.inner(before, weight))
        self._trace['refine'].append(np.inner(after, weight))
        if self._stop_exploration:
            self._trace['entropy'].append(0.0)
            self._trace['mle'].append(1.0)
            return

        if self._cfg.hpo.fedex.psn:
            # policy gradients
            pg_obj = .0
            for i, theta in enumerate(self._theta):
                for idx, cidx, s, w in zip(
                        index, cids, after - before if self._diff else after,
                        weight):
                    pg_obj += w * -1.0 * (s - baseline) * torch.log(
                        torch.clip(theta[cidx][idx[i]], min=1e-8, max=1.0))
            pg_loss = -1.0 * pg_obj
            pg_loss.backward()
            self._pn_optimizer.step()
            self._policy_net.eval()
            thetas4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
        else:
            for i, (z, theta) in enumerate(zip(self._z, self._theta)):
                grad = np.zeros(len(z))
                for idx, s, w in zip(index,
                                     after - before if self._diff else after,
                                     weight):
                    grad[idx[i]] += w * (s - baseline) / theta[idx[i]]
                if self._sched == 'adaptive':
                    self._store[i] += norm(grad, float('inf'))**2
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'aggressive':
                    denom = 1.0 if np.all(
                        grad == 0.0) else norm(grad, float('inf'))
                elif self._sched == 'auto':
                    self._store[i] += 1.0
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'constant':
                    denom = 1.0
                elif self._sched == 'scale':
                    denom = 1.0 / np.sqrt(2.0 * np.log(len(grad))) if len(
                        grad) > 1 else float('inf')
                else:
                    raise NotImplementedError
                eta = self._eta0[i] / denom
                z -= eta * grad
                z -= logsumexp(z)
                self._theta[i] = np.exp(z)
            thetas4stat = self._theta

        self._trace['entropy'].append(self.entropy(thetas4stat))
        self._trace['mle'].append(self.mle(thetas4stat))
        if self._trace['entropy'][-1] < self._cutoff:
            self._stop_exploration = True

        logger.info(
            'Server: Updated policy as {} with entropy {:f} and mle {:f}'.
            format(thetas4stat, self._trace['entropy'][-1],
                   self._trace['mle'][-1]))
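
    # In the non-personalized branch above, the update is an exponentiated-
    # gradient (mirror descent) step on the policy logits: for the arm j
    # played for hyperparameter i, grad[j] accumulates
    # w * (cost - baseline) / theta[j] (an importance-weighted cost
    # estimate), and then
    #     z <- z - eta * grad,  z <- z - logsumexp(z),  theta = exp(z),
    # so theta remains a valid probability distribution after every round.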

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer; when enough messages have been
        received, trigger some events (such as performing aggregation,
        evaluation, and moving to the next training round)
        """
        if min_received_num is None:
            min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result:
            min_received_num = len(list(self.comm_manager.neighbors.keys()))

        # whether to move to a new training round or finish the evaluation
        move_on_flag = True
        if self.check_buffer(self.state, min_received_num, check_eval_result):

            if not check_eval_result:  # in the training process
                mab_feedbacks = list()
                # Get all the messages
                train_msg_buffer = self.msg_buffer['train'][self.state]
                for model_idx in range(self.model_num):
                    model = self.models[model_idx]
                    aggregator = self.aggregators[model_idx]
                    msg_list = list()
                    for client_id in train_msg_buffer:
                        if self.model_num == 1:
                            msg_list.append(
                                tuple(train_msg_buffer[client_id][0:2]))
                        else:
                            train_data_size, model_para_multiple = \
                                train_msg_buffer[client_id][0:2]
                            msg_list.append((train_data_size,
                                             model_para_multiple[model_idx]))

                        # collect feedbacks for updating the policy
                        if model_idx == 0:
                            mab_feedbacks.append(
                                train_msg_buffer[client_id][2])

                    # Trigger the monitor here (for training)
                    self._monitor.calc_model_metric(self.model.state_dict(),
                                                    msg_list,
                                                    rnd=self.state)

                    # Aggregate
                    agg_info = {
                        'client_feedback': msg_list,
                        'recover_fun': self.recover_fun
                    }
                    result = aggregator.aggregate(agg_info)
                    model.load_state_dict(result, strict=False)
                    # aggregator.update(result)

                # update the policy
                self.update_policy(mab_feedbacks)

                self.state += 1
                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(
                        'Server: Starting evaluation at round {:d}.'.format(
                            self.state))
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()

                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)
                else:
                    # Final Evaluate
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()

            else:  # in the evaluation process
                # Get all the messages & aggregate
                formatted_eval_res = \
                    self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self.check_and_save()
        else:
            move_on_flag = False

        return move_on_flag

    def check_and_save(self):
        """
        To save the results, the FedEx policy, and the model after each
        evaluation, and to decide whether to early-stop or finish training
        """
        # early stopping
        should_stop = False

        if "Results_weighted_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_weighted_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_weighted_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        elif "Results_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        else:
            should_stop = False

        if should_stop:
            self.state = self.total_round_num + 1

        if should_stop or self.state == self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'merging results.')
            # last round
            self.save_best_results()

            if self._cfg.federate.save_to != '':
                # save the policy
                ckpt = dict()
                if self._cfg.hpo.fedex.psn:
                    psn_pi_ckpt_path = self._cfg.federate.save_to[
                        :self._cfg.federate.save_to.rfind('.')] + \
                        "_pfedex.pt"
                    torch.save(
                        {
                            'client_encodings': self._client_encodings,
                            'policy_net': self._policy_net.state_dict()
                        }, psn_pi_ckpt_path)
                else:
                    z_list = [z.tolist() for z in self._z]
                    ckpt['z'] = z_list
                    ckpt['store'] = self._store
                ckpt['stop'] = self._stop_exploration
                ckpt['global'] = self.trace('global').tolist()
                ckpt['refine'] = self.trace('refine').tolist()
                ckpt['entropy'] = self.trace('entropy').tolist()
                ckpt['mle'] = self.trace('mle').tolist()
                pi_ckpt_path = self._cfg.federate.save_to[
                    :self._cfg.federate.save_to.rfind('.')] + "_fedex.yaml"
                with open(pi_ckpt_path, 'w') as ops:
                    yaml.dump(ckpt, ops)

            if self.model_num > 1:
                model_para = [model.state_dict() for model in self.models]
            else:
                model_para = self.model.state_dict()
            self.comm_manager.send(
                Message(msg_type='finish',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state,
                        content=model_para))

        if self.state == self.total_round_num:
            # break out of the loop for distributed mode
            self.state += 1