107 lines
3.5 KiB
Python
107 lines
3.5 KiB
Python
import os
|
|
import logging
|
|
import numpy as np
|
|
import ConfigSpace as CS
|
|
from federatedscope.autotune.utils import eval_in_fs, log2wandb, \
|
|
summarize_hpo_results
|
|
from smac.facade.smac_bb_facade import SMAC4BB
|
|
from smac.facade.smac_hpo_facade import SMAC4HPO
|
|
from smac.scenario.scenario import Scenario
|
|
|
|
logging.basicConfig(level=logging.WARNING)
|
|
logger = logging.getLogger(__name__)
|
|
|
|
trial_index = 0
|
|
|
|
|
|
def run_smac(cfg, scheduler, client_cfgs=None):
|
|
config_space = scheduler._search_space
|
|
init_configs = []
|
|
perfs = []
|
|
|
|
def optimization_function_wrapper(config):
|
|
"""
|
|
Used as an evaluation function for SMAC optimizer.
|
|
Args:
|
|
config: configurations of FS run.
|
|
Returns:
|
|
Best results of server of specific FS run.
|
|
"""
|
|
global trial_index
|
|
|
|
budget = cfg.hpo.sha.budgets[-1]
|
|
results = eval_in_fs(cfg, config, budget, client_cfgs, trial_index)
|
|
key1, key2 = cfg.hpo.metric.split('.')
|
|
res = results[key1][key2]
|
|
config = dict(config)
|
|
config['federate.total_round_num'] = budget
|
|
init_configs.append(config)
|
|
perfs.append(res)
|
|
logger.info(f'Evaluate the {len(perfs)-1}-th config '
|
|
f'{config}, and get performance {res}')
|
|
if cfg.wandb.use:
|
|
tmp_results = \
|
|
summarize_hpo_results(init_configs,
|
|
perfs,
|
|
white_list=set(config_space.keys()),
|
|
desc=cfg.hpo.larger_better,
|
|
is_sorted=False)
|
|
log2wandb(len(perfs) - 1, config, results, cfg, tmp_results)
|
|
trial_index += 1
|
|
|
|
if cfg.hpo.larger_better:
|
|
return -res
|
|
else:
|
|
return res
|
|
|
|
def summarize():
|
|
results = summarize_hpo_results(init_configs,
|
|
perfs,
|
|
white_list=set(config_space.keys()),
|
|
desc=cfg.hpo.larger_better,
|
|
use_wandb=cfg.wandb.use)
|
|
logger.info(
|
|
"========================== HPO Final ==========================")
|
|
logger.info("\n{}".format(results))
|
|
results.to_csv(os.path.join(cfg.hpo.working_folder, 'results.csv'))
|
|
logger.info("====================================================")
|
|
|
|
return perfs
|
|
|
|
if cfg.hpo.scheduler.startswith('wrap_'):
|
|
ss = CS.ConfigurationSpace()
|
|
ss.add_hyperparameter(config_space['hpo.table.idx'])
|
|
config_space = ss
|
|
|
|
if cfg.hpo.sha.iter != 0:
|
|
n_iterations = cfg.hpo.sha.iter
|
|
else:
|
|
n_iterations = -int(
|
|
np.log(cfg.hpo.sha.budgets[0] / cfg.hpo.sha.budgets[-1]) /
|
|
np.log(cfg.hpo.sha.elim_rate)) + 1
|
|
|
|
scenario = Scenario({
|
|
"run_obj": "quality",
|
|
"runcount-limit": n_iterations,
|
|
"cs": config_space,
|
|
"output_dir": cfg.hpo.working_folder,
|
|
"deterministic": "true",
|
|
"limit_resources": False
|
|
})
|
|
|
|
if cfg.hpo.scheduler.endswith('bo_gp'):
|
|
smac = SMAC4BB(model_type='gp',
|
|
scenario=scenario,
|
|
tae_runner=optimization_function_wrapper)
|
|
elif cfg.hpo.scheduler.endswith('bo_rf'):
|
|
smac = SMAC4HPO(scenario=scenario,
|
|
tae_runner=optimization_function_wrapper)
|
|
else:
|
|
raise NotImplementedError
|
|
try:
|
|
smac.optimize()
|
|
finally:
|
|
smac.solver.incumbent
|
|
summarize()
|
|
return perfs
|