FS-TFP/federatedscope/autotune/fedex/server.py

import os
import logging
from itertools import product

import yaml
import numpy as np
from numpy.linalg import norm
from scipy.special import logsumexp
import torch

from federatedscope.core.message import Message
from federatedscope.core.workers import Server
from federatedscope.core.auxiliaries.utils import merge_dict_of_results
from federatedscope.autotune.fedex.utils import HyperNet

logger = logging.getLogger(__name__)


def discounted_mean(trace, factor=1.0):
    weight = factor**np.flip(np.arange(len(trace)), axis=0)
    return np.inner(trace, weight) / weight.sum()
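

# A minimal worked example of `discounted_mean` above (hypothetical values):
# for trace = [1.0, 2.0, 3.0] and factor = 0.9, the weights are
# 0.9**[2, 1, 0] = [0.81, 0.9, 1.0], so the most recent entry weighs most and
# the result is (0.81 + 1.8 + 3.0) / 2.71 ~= 2.07.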


class FedExServer(Server):
    """Some code snippets are borrowed from the open-sourced FedEx (
    https://github.com/mkhodak/FedEx)
    """
    def __init__(self,
                 ID=-1,
                 state=0,
                 config=None,
                 data=None,
                 model=None,
                 client_num=5,
                 total_round_num=10,
                 device='cpu',
                 strategy=None,
                 **kwargs):
        super(FedExServer,
              self).__init__(ID, state, config, data, model, client_num,
                             total_round_num, device, strategy, **kwargs)

        # initialize action space and the policy
        with open(config.hpo.fedex.ss, 'r') as ips:
            ss = yaml.load(ips, Loader=yaml.FullLoader)

        if next(iter(ss.keys())).startswith('arm'):
            # This is a flattened action space;
            # ensure the order is unchanged
            ss = sorted([(int(k[3:]), v) for k, v in ss.items()],
                        key=lambda x: x[0])
            self._grid = []
            self._cfsp = [[tp[1] for tp in ss]]
        else:
            # This is not a flat search space;
            # be careful with the order
            self._grid = sorted(ss.keys())
            self._cfsp = [ss[pn] for pn in self._grid]
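
        # An illustrative (hypothetical) search-space file:
        #   lr: [0.01, 0.1, 1.0]
        #   weight_decay: [0.0, 0.0005]
        # would yield self._grid == ['lr', 'weight_decay'] and
        # self._cfsp == [[0.01, 0.1, 1.0], [0.0, 0.0005]], i.e. one candidate
        # set per hyperparameter.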
        sizes = [len(cand_set) for cand_set in self._cfsp]
        eta0 = 'auto' if config.hpo.fedex.eta0 <= .0 else float(
            config.hpo.fedex.eta0)
        # 'auto' uses the standard step size sqrt(2 * log(n)) for
        # exponentiated gradient over an n-point candidate set
        self._eta0 = [
            np.sqrt(2.0 * np.log(size)) if eta0 == 'auto' else eta0
            for size in sizes
        ]
        self._sched = config.hpo.fedex.sched
        self._cutoff = config.hpo.fedex.cutoff
        self._baseline = config.hpo.fedex.gamma
        self._diff = config.hpo.fedex.diff

        if self._cfg.hpo.fedex.psn:
            # personalized policy
            # TODO: client-wise RFF
            self._client_encodings = torch.randn(
                (client_num, 8), device=device) / np.sqrt(8)
            self._policy_net = HyperNet(
                self._client_encodings.shape[-1],
                sizes,
                client_num,
                device,
            ).to(device)
            self._policy_net.eval()
            theta4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
            self._pn_optimizer = torch.optim.Adam(
                self._policy_net.parameters(),
                lr=self._cfg.hpo.fedex.pi_lr,
                weight_decay=1e-5)
        else:
            # start from the uniform distribution over each candidate set
            self._z = [np.full(size, -np.log(size)) for size in sizes]
            self._theta = [np.exp(z) for z in self._z]
            theta4stat = self._theta
        self._store = [0.0 for _ in sizes]
        self._stop_exploration = False
        # 'global'/'refine' track the weighted validation loss before/after
        # the clients' local refinement, respectively
        self._trace = {
            'global': [],
            'refine': [],
            'entropy': [self.entropy(theta4stat)],
            'mle': [self.mle(theta4stat)]
        }

        if self._cfg.federate.restore_from != '':
            if not os.path.exists(self._cfg.federate.restore_from):
                logger.warning(f'Invalid `restore_from`:'
                               f' {self._cfg.federate.restore_from}.')
            else:
                pi_ckpt_path = self._cfg.federate.restore_from[
                    :self._cfg.federate.restore_from.rfind('.')] \
                    + "_fedex.yaml"
                with open(pi_ckpt_path, 'r') as ips:
                    ckpt = yaml.load(ips, Loader=yaml.FullLoader)
                if self._cfg.hpo.fedex.psn:
                    psn_pi_ckpt_path = self._cfg.federate.restore_from[
                        :self._cfg.federate.restore_from.rfind('.')] \
                        + "_pfedex.pt"
                    psn_pi = torch.load(psn_pi_ckpt_path,
                                        map_location=device)
                    self._client_encodings = psn_pi['client_encodings']
                    self._policy_net.load_state_dict(psn_pi['policy_net'])
                else:
                    self._z = [np.asarray(z) for z in ckpt['z']]
                    self._theta = [np.exp(z) for z in self._z]
                    self._store = ckpt['store']
                    self._stop_exploration = ckpt['stop']
                    self._trace = dict()
                    self._trace['global'] = ckpt['global']
                    self._trace['refine'] = ckpt['refine']
                    self._trace['entropy'] = ckpt['entropy']
                    self._trace['mle'] = ckpt['mle']
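
        # For reference, the auxiliary checkpoint consumed above is a plain
        # yaml file written by `check_and_save`; a (hypothetical)
        # non-personalized example with one 3-arm hyperparameter:
        #   z: [[-1.1, -1.1, -1.1]]
        #   store: [0.0]
        #   stop: false
        #   global: [0.9]
        #   refine: [0.8]
        #   entropy: [1.1]
        #   mle: [0.33]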

    def entropy(self, thetas):
        # Since the per-hyperparameter distributions are independent, the
        # joint entropy equals the sum of the per-dimension entropies; the
        # explicit enumeration below computes the same quantity but is
        # exponential in the number of hyperparameters.
        if self._cfg.hpo.fedex.psn:
            # client-wise policies: average the joint entropy over clients
            entropy = 0.0
            for i in range(thetas[0].shape[0]):
                for probs in product(*(theta[i][theta[i] > 0.0]
                                       for theta in thetas)):
                    prob = np.prod(probs)
                    entropy -= prob * np.log(prob)
            return entropy / float(thetas[0].shape[0])
        else:
            entropy = 0.0
            for probs in product(*(theta[theta > 0.0] for theta in thetas)):
                prob = np.prod(probs)
                entropy -= prob * np.log(prob)
            return entropy

    def mle(self, thetas):
        if self._cfg.hpo.fedex.psn:
            return np.prod([theta.max(-1) for theta in thetas], 0).mean()
        else:
            return np.prod([theta.max() for theta in thetas])
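
    # `mle` above is the probability mass of the most likely joint
    # configuration; e.g. (hypothetical values) the thetas
    # [np.array([0.7, 0.3]), np.array([0.6, 0.4])] give 0.7 * 0.6 == 0.42.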

    def trace(self, key):
        '''returns the trace of one of the tracked quantities

        Args:
            key (str): 'global', 'refine', 'entropy', or 'mle'

        Returns:
            numpy vector with length equal to the number of rounds so far.
        '''
        return np.array(self._trace[key])

    def sample(self, thetas):
        """Samples a configuration based on the current probability vectors.

        Arguments:
            thetas (list): one probability vector per hyperparameter.
        """
        # determine the index of the chosen value for each hyperparameter
        if self._stop_exploration:
            # exploit: pick the most likely value
            cfg_idx = [int(theta.argmax()) for theta in thetas]
        else:
            cfg_idx = [
                np.random.choice(len(theta), p=theta) for theta in thetas
            ]

        # get the sampled value(s)
        if self._grid:
            sampled_cfg = {
                pn: cands[i]
                for pn, cands, i in zip(self._grid, self._cfsp, cfg_idx)
            }
        else:
            sampled_cfg = self._cfsp[0][cfg_idx[0]]

        return cfg_idx, sampled_cfg
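
    # A minimal sketch of `sample` (hypothetical values): with
    # self._grid == ['lr'], self._cfsp == [[0.01, 0.1, 1.0]] and
    # thetas == [np.array([0.2, 0.5, 0.3])], a call may return
    # ([1], {'lr': 0.1}), where arm index 1 is drawn with probability 0.5.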

    def broadcast_model_para(self,
                             msg_type='model_para',
                             sample_client_num=-1,
                             filter_unseen_clients=True):
        """
        To broadcast the message to all clients or sampled clients
        """
        if filter_unseen_clients:
            # to filter out the unseen clients when sampling
            self.sampler.change_state(self.unseen_clients_id, 'unseen')

        if sample_client_num > 0:
            receiver = self.sampler.sample(size=sample_client_num)
        else:
            # broadcast to all clients
            receiver = list(self.comm_manager.neighbors.keys())
            if msg_type == 'model_para':
                self.sampler.change_state(receiver, 'working')

        if self._noise_injector is not None and msg_type == 'model_para':
            # Inject noise only when broadcasting parameters
            for model_idx_i in range(len(self.models)):
                num_sample_clients = [
                    v["num_sample"] for v in self.join_in_info.values()
                ]
                self._noise_injector(self._cfg, num_sample_clients,
                                     self.models[model_idx_i])

        if self.model_num > 1:
            model_para = [model.state_dict() for model in self.models]
        else:
            model_para = self.model.state_dict()

        # sample the hyper-parameter config specific to the clients
        if self._cfg.hpo.fedex.psn:
            self._policy_net.train()
            self._pn_optimizer.zero_grad()
            self._theta = self._policy_net(self._client_encodings)

        for rcv_idx in receiver:
            if self._cfg.hpo.fedex.psn:
                cfg_idx, sampled_cfg = self.sample([
                    theta[rcv_idx - 1].detach().cpu().numpy()
                    for theta in self._theta
                ])
            else:
                cfg_idx, sampled_cfg = self.sample(self._theta)
            content = {
                'model_param': model_para,
                'arms': cfg_idx,
                'hyperparam': sampled_cfg
            }
            self.comm_manager.send(
                Message(msg_type=msg_type,
                        sender=self.ID,
                        receiver=[rcv_idx],
                        state=self.state,
                        content=content))

        if self._cfg.federate.online_aggr:
            for idx in range(self.model_num):
                self.aggregators[idx].reset()

        if filter_unseen_clients:
            # restore the state of the unseen clients within sampler
            self.sampler.change_state(self.unseen_clients_id, 'seen')
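
    # Each selected client thus receives the global weights together with its
    # sampled arm indices and the decoded hyperparameters, e.g. (hypothetical
    # values):
    #   {'model_param': <state_dict>,
    #    'arms': [1, 0],
    #    'hyperparam': {'lr': 0.1, 'weight_decay': 0.0}}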

    def callback_funcs_model_para(self, message: Message):
        round, sender, content = message.state, message.sender, \
            message.content
        self.sampler.change_state(sender, 'idle')
        # For a new round
        if round not in self.msg_buffer['train'].keys():
            self.msg_buffer['train'][round] = dict()
        self.msg_buffer['train'][round][sender] = content

        if self._cfg.federate.online_aggr:
            self.aggregator.inc(tuple(content[0:2]))

        return self.check_and_move_on()

    def update_policy(self, feedbacks):
        """Update the policy. This implementation is borrowed from the
        open-sourced FedEx (
        https://github.com/mkhodak/FedEx/blob/ \
        150fac03857a3239429734d59d319da71191872e/hyper.py#L151)

        Arguments:
            feedbacks (list): each element is a dict containing "arms" and
                necessary feedback.
        """
        index = [elem['arms'] for elem in feedbacks]
        cids = [elem['client_id'] for elem in feedbacks]
        before = np.asarray(
            [elem['val_avg_loss_before'] for elem in feedbacks])
        after = np.asarray([elem['val_avg_loss_after'] for elem in feedbacks])
        weight = np.asarray([elem['val_total'] for elem in feedbacks],
                            dtype=np.float64)
        weight /= np.sum(weight)

        if self._trace['refine']:
            trace = self.trace('refine')
            if self._diff:
                trace -= self.trace('global')
            baseline = discounted_mean(trace, self._baseline)
        else:
            baseline = 0.0
        self._trace['global'].append(np.inner(before, weight))
        self._trace['refine'].append(np.inner(after, weight))
        if self._stop_exploration:
            self._trace['entropy'].append(0.0)
            self._trace['mle'].append(1.0)
            return

        if self._cfg.hpo.fedex.psn:
            # policy gradients (REINFORCE with a moving baseline)
            pg_obj = .0
            for i, theta in enumerate(self._theta):
                for idx, cidx, s, w in zip(
                        index, cids, after - before if self._diff else after,
                        weight):
                    pg_obj += w * -1.0 * (s - baseline) * torch.log(
                        torch.clip(theta[cidx][idx[i]], min=1e-8, max=1.0))
            pg_loss = -1.0 * pg_obj
            pg_loss.backward()
            self._pn_optimizer.step()
            self._policy_net.eval()
            thetas4stat = [
                theta.detach().cpu().numpy()
                for theta in self._policy_net(self._client_encodings)
            ]
        else:
            # exponentiated gradient descent on each probability simplex
            for i, (z, theta) in enumerate(zip(self._z, self._theta)):
                grad = np.zeros(len(z))
                for idx, s, w in zip(index,
                                     after - before if self._diff else after,
                                     weight):
                    grad[idx[i]] += w * (s - baseline) / theta[idx[i]]
                if self._sched == 'adaptive':
                    self._store[i] += norm(grad, float('inf'))**2
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'aggressive':
                    denom = 1.0 if np.all(
                        grad == 0.0) else norm(grad, float('inf'))
                elif self._sched == 'auto':
                    self._store[i] += 1.0
                    denom = np.sqrt(self._store[i])
                elif self._sched == 'constant':
                    denom = 1.0
                elif self._sched == 'scale':
                    denom = 1.0 / np.sqrt(2.0 * np.log(len(grad))) if len(
                        grad) > 1 else float('inf')
                else:
                    raise NotImplementedError
                eta = self._eta0[i] / denom
                z -= eta * grad
                z -= logsumexp(z)
                self._theta[i] = np.exp(z)
            thetas4stat = self._theta

        self._trace['entropy'].append(self.entropy(thetas4stat))
        self._trace['mle'].append(self.mle(thetas4stat))
        if self._trace['entropy'][-1] < self._cutoff:
            self._stop_exploration = True
        logger.info(
            'Server: Updated policy as {} with entropy {:f} and mle {:f}'.
            format(thetas4stat, self._trace['entropy'][-1],
                   self._trace['mle'][-1]))
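
    # A minimal sketch of the exponentiated-gradient step above, assuming a
    # single hyperparameter with three arms and eta == 1.0 (hypothetical
    # values):
    #   z = np.full(3, -np.log(3))        # uniform log-probabilities
    #   grad = np.array([0.0, 0.5, 0.0])  # arm 1 incurred a high loss
    #   z -= 1.0 * grad
    #   z -= logsumexp(z)                 # renormalize on the simplex
    #   theta = np.exp(z)                 # arm 1 now has probability ~0.23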

    def check_and_move_on(self,
                          check_eval_result=False,
                          min_received_num=None):
        """
        To check the message_buffer; when enough messages have been received,
        trigger some events (such as performing aggregation, evaluation,
        and moving to the next training round)
        """
        if min_received_num is None:
            min_received_num = self._cfg.federate.sample_client_num
        assert min_received_num <= self.sample_client_num

        if check_eval_result:
            min_received_num = len(list(self.comm_manager.neighbors.keys()))

        move_on_flag = True  # To record whether moving to a new training
        # round or finishing the evaluation
        if self.check_buffer(self.state, min_received_num, check_eval_result):
            if not check_eval_result:  # in the training process
                mab_feedbacks = list()
                # Get all the messages
                train_msg_buffer = self.msg_buffer['train'][self.state]
                for model_idx in range(self.model_num):
                    model = self.models[model_idx]
                    aggregator = self.aggregators[model_idx]
                    msg_list = list()
                    for client_id in train_msg_buffer:
                        if self.model_num == 1:
                            msg_list.append(
                                tuple(train_msg_buffer[client_id][0:2]))
                        else:
                            train_data_size, model_para_multiple = \
                                train_msg_buffer[client_id][0:2]
                            msg_list.append((train_data_size,
                                             model_para_multiple[model_idx]))

                        # collect feedbacks for updating the policy
                        if model_idx == 0:
                            mab_feedbacks.append(
                                train_msg_buffer[client_id][2])

                    # Trigger the monitor here (for training)
                    self._monitor.calc_model_metric(self.model.state_dict(),
                                                    msg_list,
                                                    rnd=self.state)

                    # Aggregate
                    agg_info = {
                        'client_feedback': msg_list,
                        'recover_fun': self.recover_fun
                    }
                    result = aggregator.aggregate(agg_info)
                    model.load_state_dict(result, strict=False)
                    # aggregator.update(result)

                # update the policy
                self.update_policy(mab_feedbacks)

                self.state += 1
                if self.state % self._cfg.eval.freq == 0 and self.state != \
                        self.total_round_num:
                    # Evaluate
                    logger.info(
                        'Server: Starting evaluation at round {:d}.'.format(
                            self.state))
                    self.eval()

                if self.state < self.total_round_num:
                    # Move to the next round of training
                    logger.info(
                        f'----------- Starting a new training round (Round '
                        f'#{self.state}) -------------')
                    # Clean the msg_buffer
                    self.msg_buffer['train'][self.state - 1].clear()

                    self.broadcast_model_para(
                        msg_type='model_para',
                        sample_client_num=self.sample_client_num)
                else:
                    # Final evaluation
                    logger.info('Server: Training is finished! Starting '
                                'evaluation.')
                    self.eval()
            else:  # in the evaluation process
                # Get all the messages & aggregate
                formatted_eval_res = self.merge_eval_results_from_all_clients()
                self.history_results = merge_dict_of_results(
                    self.history_results, formatted_eval_res)
                self.check_and_save()
        else:
            move_on_flag = False

        return move_on_flag

    def check_and_save(self):
        """
        To save the results and the model after each evaluation
        """
        # early stopping
        should_stop = False
        if "Results_weighted_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_weighted_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_weighted_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        elif "Results_avg" in self.history_results and \
                self._cfg.eval.best_res_update_round_wise_key in \
                self.history_results['Results_avg']:
            should_stop = self.early_stopper.track_and_check(
                self.history_results['Results_avg'][
                    self._cfg.eval.best_res_update_round_wise_key])
        else:
            should_stop = False

        if should_stop:
            self.state = self.total_round_num + 1

        if should_stop or self.state == self.total_round_num:
            logger.info('Server: Final evaluation is finished! Starting '
                        'merging results.')
            # last round
            self.save_best_results()

            if self._cfg.federate.save_to != '':
                # save the policy
                ckpt = dict()
                if self._cfg.hpo.fedex.psn:
                    psn_pi_ckpt_path = self._cfg.federate.save_to[
                        :self._cfg.federate.save_to.rfind('.')] \
                        + "_pfedex.pt"
                    torch.save(
                        {
                            'client_encodings': self._client_encodings,
                            'policy_net': self._policy_net.state_dict()
                        }, psn_pi_ckpt_path)
                else:
                    z_list = [z.tolist() for z in self._z]
                    ckpt['z'] = z_list
                    ckpt['store'] = self._store
                    ckpt['stop'] = self._stop_exploration
                    ckpt['global'] = self.trace('global').tolist()
                    ckpt['refine'] = self.trace('refine').tolist()
                    ckpt['entropy'] = self.trace('entropy').tolist()
                    ckpt['mle'] = self.trace('mle').tolist()
                pi_ckpt_path = self._cfg.federate.save_to[
                    :self._cfg.federate.save_to.rfind('.')] + "_fedex.yaml"
                with open(pi_ckpt_path, 'w') as ops:
                    yaml.dump(ckpt, ops)

            if self.model_num > 1:
                model_para = [model.state_dict() for model in self.models]
            else:
                model_para = self.model.state_dict()
            self.comm_manager.send(
                Message(msg_type='finish',
                        sender=self.ID,
                        receiver=list(self.comm_manager.neighbors.keys()),
                        state=self.state,
                        content=model_para))

        if self.state == self.total_round_num:
            # break out of the loop for the distributed mode
            self.state += 1