166 lines
5.9 KiB
Python
166 lines
5.9 KiB
Python
import os
|
|
import time
|
|
import logging
|
|
|
|
from os.path import join as osp
|
|
import numpy as np
|
|
import ConfigSpace as CS
|
|
import hpbandster.core.nameserver as hpns
|
|
from hpbandster.core.worker import Worker
|
|
from hpbandster.optimizers import BOHB, HyperBand, RandomSearch
|
|
|
|
from federatedscope.autotune.utils import eval_in_fs, log2wandb, \
|
|
summarize_hpo_results
|
|
|
|
# Import-time logging setup: request WARNING as the root logger level.
logging.basicConfig(level=logging.WARNING)
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def clear_cache(working_folder):
    """Delete cached checkpoint files from ``working_folder``.

    Every directory entry whose name ends with ``.pth`` is removed; all
    other entries are left untouched. The scan is not recursive.

    Args:
        working_folder: Path of the directory to clean.
    """
    checkpoint_names = [
        entry for entry in os.listdir(working_folder)
        if entry.endswith('.pth')
    ]
    for entry in checkpoint_names:
        os.remove(osp(working_folder, entry))
|
|
|
|
|
|
class MyRandomSearch(RandomSearch):
    """``RandomSearch`` optimizer that also remembers the HPO working folder."""

    def __init__(self, working_folder, **kwargs):
        """Record ``working_folder`` and delegate the rest to ``RandomSearch``.

        Args:
            working_folder: Folder associated with this HPO run; stored on
                the instance as ``self.working_folder``.
            **kwargs: Passed unchanged to ``RandomSearch.__init__``.
        """
        # Set the attribute before the base initializer runs, as the
        # original did.
        self.working_folder = working_folder
        super().__init__(**kwargs)
|
|
|
|
|
|
class MyBOHB(BOHB):
    """``BOHB`` optimizer that purges cached checkpoints before each iteration.

    Before delegating to ``BOHB.get_next_iteration``, every ``*.pth`` file in
    the working folder is removed via ``clear_cache``.
    """

    def __init__(self, working_folder, **kwargs):
        """Record ``working_folder`` and delegate the rest to ``BOHB``.

        Args:
            working_folder: Folder whose ``.pth`` files are cleared before
                each new iteration.
            **kwargs: Passed unchanged to ``BOHB.__init__``.
        """
        self.working_folder = working_folder
        super().__init__(**kwargs)

    def get_next_iteration(self, iteration, iteration_kwargs=None):
        """Clear cached checkpoints, then build the next iteration.

        Args:
            iteration: Iteration index, forwarded to the base class.
            iteration_kwargs: Optional dict of extra arguments forwarded to
                the base class; ``None`` is treated as an empty dict.

        Returns:
            Whatever ``BOHB.get_next_iteration`` returns.
        """
        # A `{}` default would be one dict shared by every call; use a
        # ``None`` sentinel and create a fresh dict per call instead.
        if iteration_kwargs is None:
            iteration_kwargs = {}
        # isdir (not exists): a path that exists but is not a directory
        # would make os.listdir() inside clear_cache raise.
        if os.path.isdir(self.working_folder):
            clear_cache(self.working_folder)
        return super().get_next_iteration(iteration, iteration_kwargs)
|
|
|
|
|
|
class MyHyperBand(HyperBand):
    """``HyperBand`` optimizer that purges cached checkpoints per iteration.

    Before delegating to ``HyperBand.get_next_iteration``, every ``*.pth``
    file in the working folder is removed via ``clear_cache``.
    """

    def __init__(self, working_folder, **kwargs):
        """Record ``working_folder`` and delegate the rest to ``HyperBand``.

        Args:
            working_folder: Folder whose ``.pth`` files are cleared before
                each new iteration.
            **kwargs: Passed unchanged to ``HyperBand.__init__``.
        """
        self.working_folder = working_folder
        super().__init__(**kwargs)

    def get_next_iteration(self, iteration, iteration_kwargs=None):
        """Clear cached checkpoints, then build the next iteration.

        Args:
            iteration: Iteration index, forwarded to the base class.
            iteration_kwargs: Optional dict of extra arguments forwarded to
                the base class; ``None`` is treated as an empty dict.

        Returns:
            Whatever ``HyperBand.get_next_iteration`` returns.
        """
        # A `{}` default would be one dict shared by every call; use a
        # ``None`` sentinel and create a fresh dict per call instead.
        if iteration_kwargs is None:
            iteration_kwargs = {}
        # isdir (not exists): a path that exists but is not a directory
        # would make os.listdir() inside clear_cache raise.
        if os.path.isdir(self.working_folder):
            clear_cache(self.working_folder)
        return super().get_next_iteration(iteration, iteration_kwargs)
|
|
|
|
|
|
class MyWorker(Worker):
    """hpbandster worker that evaluates configurations with ``eval_in_fs``.

    The worker keeps a history of every evaluated configuration and its
    performance so that a summary table can be produced at the end of the
    HPO run (and, optionally, logged to wandb after each trial).
    """

    def __init__(self,
                 cfg,
                 ss,
                 sleep_interval=0,
                 client_cfgs=None,
                 *args,
                 **kwargs):
        """Initialize the worker.

        Args:
            cfg: Configuration object; this class reads ``cfg.hpo.metric``,
                ``cfg.hpo.larger_better``, ``cfg.hpo.working_folder`` and
                ``cfg.wandb.use``.
            ss: Search space; only its ``keys()`` are used, as the white
                list of columns when summarizing results.
            sleep_interval: Seconds to sleep after each evaluation.
            client_cfgs: Optional client-specific configs, forwarded to
                ``eval_in_fs``.
            *args: Extra positional arguments for ``Worker.__init__``.
            **kwargs: Extra keyword arguments for ``Worker.__init__``.
        """
        # Forward positional arguments too; previously they were accepted
        # by the signature but silently discarded.
        super().__init__(*args, **kwargs)
        self.sleep_interval = sleep_interval
        self.cfg = cfg
        self.client_cfgs = client_cfgs
        self._ss = ss
        self._init_configs = []  # evaluated configs, in evaluation order
        self._perfs = []  # performance of each evaluated config
        self.trial_index = 0  # incremented once per compute() call

    def compute(self, config, budget, **kwargs):
        """Evaluate one configuration under the given budget.

        Args:
            config: Sampled configuration (converted to a ``dict`` here).
            budget: Budget for this trial; truncated to ``int`` for
                ``eval_in_fs`` and recorded under
                ``'federate.total_round_num'``.
            **kwargs: Accepted and ignored.

        Returns:
            A dict with ``'loss'`` (the metric as a float, negated when
            ``cfg.hpo.larger_better`` is true so that smaller is better)
            and ``'info'`` (the raw metric value).
        """
        results = eval_in_fs(self.cfg, config, int(budget), self.client_cfgs,
                             self.trial_index)
        # The metric is addressed as '<outer key>.<inner key>'.
        key1, key2 = self.cfg.hpo.metric.split('.')
        res = results[key1][key2]

        config = dict(config)
        config['federate.total_round_num'] = budget
        self._init_configs.append(config)
        self._perfs.append(float(res))
        time.sleep(self.sleep_interval)

        logger.info(f'Evaluate the {len(self._perfs)-1}-th config '
                    f'{config}, and get performance {res}')
        if self.cfg.wandb.use:
            tmp_results = summarize_hpo_results(
                self._init_configs,
                self._perfs,
                white_list=set(self._ss.keys()),
                desc=self.cfg.hpo.larger_better,
                is_sorted=False)
            log2wandb(
                len(self._perfs) - 1, config, results, self.cfg, tmp_results)
        # Advance for every trial, whether or not wandb logging is enabled.
        self.trial_index += 1

        loss = float(res)
        if self.cfg.hpo.larger_better:
            # Flip the sign so a larger metric yields a smaller loss.
            loss = -loss
        return {'loss': loss, 'info': res}

    def summarize(self):
        """Summarize all evaluated trials, log them and save them as CSV.

        The summary is written to ``results.csv`` inside
        ``cfg.hpo.working_folder``.

        Returns:
            The object returned by ``summarize_hpo_results``.
        """
        results = summarize_hpo_results(self._init_configs,
                                        self._perfs,
                                        white_list=set(self._ss.keys()),
                                        desc=self.cfg.hpo.larger_better,
                                        use_wandb=self.cfg.wandb.use)
        logger.info(
            "========================== HPO Final ==========================")
        logger.info("\n{}".format(results))

        working_folder = self.cfg.hpo.working_folder
        # Make sure the destination exists so the summary is not lost at
        # the very end of a run. An empty path means the current directory,
        # which os.makedirs would reject.
        if working_folder:
            os.makedirs(working_folder, exist_ok=True)
        results.to_csv(os.path.join(working_folder, 'results.csv'))
        logger.info("====================================================")

        return results
|
|
|
|
|
|
def run_hpbandster(cfg, scheduler, client_cfgs=None):
    """Run an hpbandster-based HPO (random search, HyperBand or BOHB).

    A local name server and one background ``MyWorker`` are started, the
    optimizer selected by ``cfg.hpo.scheduler`` is run, and the worker's
    summary is produced afterwards.

    Args:
        cfg: Configuration object; reads ``cfg.hpo.scheduler``,
            ``cfg.hpo.sha.elim_rate``, ``cfg.hpo.sha.budgets``,
            ``cfg.hpo.sha.iter`` and ``cfg.hpo.working_folder``.
        scheduler: Object exposing the search space as ``_search_space``.
        client_cfgs: Optional client-specific configs, forwarded to the
            worker.

    Returns:
        A list with the ``info`` field of every run.

    Raises:
        ValueError: If ``cfg.hpo.scheduler`` is not a supported scheduler.
    """
    config_space = scheduler._search_space
    if cfg.hpo.scheduler.startswith('wrap_'):
        # Wrapped schedulers search only over the table index.
        wrapped_space = CS.ConfigurationSpace()
        wrapped_space.add_hyperparameter(config_space['hpo.table.idx'])
        config_space = wrapped_space

    # Resolve the optimizer class BEFORE starting the name server and the
    # worker, so that an unsupported scheduler cannot leave them running.
    if cfg.hpo.scheduler in ['rs', 'wrap_rs']:
        optimizer_cls = MyRandomSearch
    elif cfg.hpo.scheduler in ['hb', 'wrap_hb']:
        optimizer_cls = MyHyperBand
    elif cfg.hpo.scheduler in ['bo_kde', 'bohb', 'wrap_bo_kde', 'wrap_bohb']:
        optimizer_cls = MyBOHB
    else:
        raise ValueError(
            f'Unsupported hpo.scheduler for hpbandster: '
            f'{cfg.hpo.scheduler!r}')

    NS = hpns.NameServer(run_id=cfg.hpo.scheduler, host='127.0.0.1', port=0)
    # port=0 requests any free port; only the port of the returned
    # (host, port) pair is needed below.
    _, ns_port = NS.start()
    try:
        w = MyWorker(sleep_interval=0,
                     ss=config_space,
                     cfg=cfg,
                     nameserver='127.0.0.1',
                     nameserver_port=ns_port,
                     run_id=cfg.hpo.scheduler,
                     client_cfgs=client_cfgs)
        w.run(background=True)

        opt_kwargs = {
            'configspace': config_space,
            'run_id': cfg.hpo.scheduler,
            'nameserver': '127.0.0.1',
            'nameserver_port': ns_port,
            'eta': cfg.hpo.sha.elim_rate,
            'min_budget': cfg.hpo.sha.budgets[0],
            'max_budget': cfg.hpo.sha.budgets[-1],
            'working_folder': cfg.hpo.working_folder
        }
        optimizer = optimizer_cls(**opt_kwargs)
        try:
            if cfg.hpo.sha.iter != 0:
                n_iterations = cfg.hpo.sha.iter
            else:
                # Derive the iteration count from the budget range:
                # floor(log_eta(max_budget / min_budget)) + 1.
                n_iterations = -int(
                    np.log(opt_kwargs['min_budget'] /
                           opt_kwargs['max_budget']) /
                    np.log(opt_kwargs['eta'])) + 1
            res = optimizer.run(n_iterations=n_iterations)
        finally:
            # Stop the optimizer and its workers even if the run failed.
            optimizer.shutdown(shutdown_workers=True)
    finally:
        # Stop the name server on both success and failure.
        NS.shutdown()

    all_runs = res.get_all_runs()
    w.summarize()

    return [x.info for x in all_runs]
|