import os
import time
import logging

from os.path import join as osp
import numpy as np
import ConfigSpace as CS
import hpbandster.core.nameserver as hpns
from hpbandster.core.worker import Worker
from hpbandster.optimizers import BOHB, HyperBand, RandomSearch

from federatedscope.autotune.utils import eval_in_fs, log2wandb, \
    summarize_hpo_results

logging.basicConfig(level=logging.WARNING)
logger = logging.getLogger(__name__)


def clear_cache(working_folder):
    """Delete every file ending in ``.pth`` directly inside
    ``working_folder``.

    Args:
        working_folder: directory to scan (not recursive). It must exist;
            ``os.listdir`` raises otherwise.
    """
    # Clear cached ckpt
    for name in os.listdir(working_folder):
        if name.endswith('.pth'):
            os.remove(osp(working_folder, name))


class MyRandomSearch(RandomSearch):
    """``RandomSearch`` that additionally remembers a working folder."""
    def __init__(self, working_folder, **kwargs):
        """Store ``working_folder`` and forward the remaining keyword
        arguments to ``RandomSearch``."""
        self.working_folder = working_folder
        super(MyRandomSearch, self).__init__(**kwargs)


class MyBOHB(BOHB):
    """``BOHB`` that clears cached checkpoints before each new iteration."""
    def __init__(self, working_folder, **kwargs):
        """Store ``working_folder`` and forward the remaining keyword
        arguments to ``BOHB``."""
        self.working_folder = working_folder
        super(MyBOHB, self).__init__(**kwargs)

    def get_next_iteration(self, iteration, iteration_kwargs=None):
        """Clear ``.pth`` files in the working folder (if it exists), then
        delegate to ``BOHB.get_next_iteration``.

        Args:
            iteration: forwarded unchanged to the parent implementation.
            iteration_kwargs: forwarded to the parent implementation;
                ``None`` is treated as an empty dict.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # across invocations.
        if iteration_kwargs is None:
            iteration_kwargs = {}
        if os.path.exists(self.working_folder):
            clear_cache(self.working_folder)
        return super(MyBOHB, self).get_next_iteration(iteration,
                                                      iteration_kwargs)


class MyHyperBand(HyperBand):
    """``HyperBand`` that clears cached checkpoints before each new
    iteration."""
    def __init__(self, working_folder, **kwargs):
        """Store ``working_folder`` and forward the remaining keyword
        arguments to ``HyperBand``."""
        self.working_folder = working_folder
        super(MyHyperBand, self).__init__(**kwargs)

    def get_next_iteration(self, iteration, iteration_kwargs=None):
        """Clear ``.pth`` files in the working folder (if it exists), then
        delegate to ``HyperBand.get_next_iteration``.

        Args:
            iteration: forwarded unchanged to the parent implementation.
            iteration_kwargs: forwarded to the parent implementation;
                ``None`` is treated as an empty dict.
        """
        # A fresh dict per call avoids sharing one mutable default object
        # across invocations.
        if iteration_kwargs is None:
            iteration_kwargs = {}
        if os.path.exists(self.working_folder):
            clear_cache(self.working_folder)
        return super(MyHyperBand,
                     self).get_next_iteration(iteration, iteration_kwargs)


class MyWorker(Worker):
    """HpBandSter worker that evaluates configurations with ``eval_in_fs``
    and records every (config, performance) pair it has seen."""
    def __init__(self,
                 cfg,
                 ss,
                 sleep_interval=0,
                 client_cfgs=None,
                 *args,
                 **kwargs):
        """
        Args:
            cfg: configuration object; this class reads ``cfg.hpo.metric``,
                ``cfg.hpo.larger_better``, ``cfg.hpo.working_folder`` and
                ``cfg.wandb.use``.
            ss: search space; only its ``keys()`` are used, as the white
                list passed to ``summarize_hpo_results``.
            sleep_interval: seconds to sleep after each evaluation.
            client_cfgs: passed through to ``eval_in_fs``.
            *args: accepted but NOT forwarded to ``Worker.__init__``.
            **kwargs: forwarded to ``Worker.__init__``.
        """
        super(MyWorker, self).__init__(**kwargs)
        self.sleep_interval = sleep_interval
        self.cfg = cfg
        self.client_cfgs = client_cfgs
        self._ss = ss
        # Parallel lists: the i-th performance belongs to the i-th config.
        self._init_configs = []
        self._perfs = []
        self.trial_index = 0

    def compute(self, config, budget, **kwargs):
        """Evaluate one configuration and report its loss.

        Args:
            config: hyperparameter configuration to evaluate.
            budget: budget for this evaluation; passed to ``eval_in_fs`` as
                ``int(budget)`` and recorded under
                ``'federate.total_round_num'``.
            **kwargs: ignored.

        Returns:
            dict with ``'loss'`` (the metric as a float, negated when
            ``cfg.hpo.larger_better`` is true) and ``'info'`` (the raw
            metric value).
        """
        results = eval_in_fs(self.cfg, config, int(budget), self.client_cfgs,
                             self.trial_index)
        # ``cfg.hpo.metric`` is expected to be '<key1>.<key2>'; any other
        # number of dot-separated parts makes the unpacking fail.
        key1, key2 = self.cfg.hpo.metric.split('.')
        res = results[key1][key2]
        config = dict(config)
        config['federate.total_round_num'] = budget
        self._init_configs.append(config)
        self._perfs.append(float(res))
        time.sleep(self.sleep_interval)
        logger.info(f'Evaluate the {len(self._perfs)-1}-th config '
                    f'{config}, and get performance {res}')
        if self.cfg.wandb.use:
            tmp_results = \
                summarize_hpo_results(self._init_configs,
                                      self._perfs,
                                      white_list=set(
                                          self._ss.keys()),
                                      desc=self.cfg.hpo.larger_better,
                                      is_sorted=False)
            log2wandb(
                len(self._perfs) - 1, config, results, self.cfg, tmp_results)
        self.trial_index += 1

        # The returned 'loss' is the metric itself, sign-flipped when a
        # larger metric is better.
        if self.cfg.hpo.larger_better:
            return {'loss': -float(res), 'info': res}
        else:
            return {'loss': float(res), 'info': res}

    def summarize(self):
        """Summarize all recorded trials, log the summary and write it to
        ``results.csv`` inside ``cfg.hpo.working_folder``.

        Returns:
            The object returned by ``summarize_hpo_results``.
        """
        results = summarize_hpo_results(self._init_configs,
                                        self._perfs,
                                        white_list=set(self._ss.keys()),
                                        desc=self.cfg.hpo.larger_better,
                                        use_wandb=self.cfg.wandb.use)
        logger.info(
            "========================== HPO Final ==========================")
        logger.info("\n{}".format(results))
        results.to_csv(os.path.join(self.cfg.hpo.working_folder,
                                    'results.csv'))
        logger.info("====================================================")

        return results


def run_hpbandster(cfg, scheduler, client_cfgs=None):
    """Run an HpBandSter optimizer selected by ``cfg.hpo.scheduler``.

    Args:
        cfg: configuration object; reads ``cfg.hpo.scheduler``,
            ``cfg.hpo.sha.elim_rate``, ``cfg.hpo.sha.budgets``,
            ``cfg.hpo.sha.iter`` and ``cfg.hpo.working_folder``.
        scheduler: object whose ``_search_space`` is used as the
            configuration space.
        client_cfgs: passed through to the worker.

    Returns:
        list with the ``info`` attribute of every run in the result.

    Raises:
        ValueError: if ``cfg.hpo.scheduler`` is not a supported name.
    """
    # Resolve the optimizer class first, so that an unsupported scheduler
    # fails before the name server or the worker thread is started.
    if cfg.hpo.scheduler in ['rs', 'wrap_rs']:
        optimizer_cls = MyRandomSearch
    elif cfg.hpo.scheduler in ['hb', 'wrap_hb']:
        optimizer_cls = MyHyperBand
    elif cfg.hpo.scheduler in ['bo_kde', 'bohb', 'wrap_bo_kde', 'wrap_bohb']:
        optimizer_cls = MyBOHB
    else:
        raise ValueError(
            f'Unsupported HPO scheduler: {cfg.hpo.scheduler!r}')

    config_space = scheduler._search_space
    if cfg.hpo.scheduler.startswith('wrap_'):
        # For 'wrap_*' schedulers only 'hpo.table.idx' is searched.
        ss = CS.ConfigurationSpace()
        ss.add_hyperparameter(config_space['hpo.table.idx'])
        config_space = ss

    NS = hpns.NameServer(run_id=cfg.hpo.scheduler, host='127.0.0.1', port=0)
    _, ns_port = NS.start()
    try:
        w = MyWorker(sleep_interval=0,
                     ss=config_space,
                     cfg=cfg,
                     nameserver='127.0.0.1',
                     nameserver_port=ns_port,
                     run_id=cfg.hpo.scheduler,
                     client_cfgs=client_cfgs)
        w.run(background=True)
        opt_kwargs = {
            'configspace': config_space,
            'run_id': cfg.hpo.scheduler,
            'nameserver': '127.0.0.1',
            'nameserver_port': ns_port,
            'eta': cfg.hpo.sha.elim_rate,
            'min_budget': cfg.hpo.sha.budgets[0],
            'max_budget': cfg.hpo.sha.budgets[-1],
            'working_folder': cfg.hpo.working_folder
        }
        optimizer = optimizer_cls(**opt_kwargs)
        try:
            if cfg.hpo.sha.iter != 0:
                n_iterations = cfg.hpo.sha.iter
            else:
                # trunc(log_eta(max_budget / min_budget)) + 1
                n_iterations = -int(
                    np.log(opt_kwargs['min_budget'] /
                           opt_kwargs['max_budget']) /
                    np.log(opt_kwargs['eta'])) + 1
            res = optimizer.run(n_iterations=n_iterations)
        finally:
            # Shut the optimizer (and its workers) down even if the run
            # fails.
            optimizer.shutdown(shutdown_workers=True)
    finally:
        # The name server is always released once it has been started.
        NS.shutdown()

    all_runs = res.get_all_runs()
    w.summarize()

    return [x.info for x in all_runs]