83 lines
3.1 KiB
Python
83 lines
3.1 KiB
Python
import datetime
import time

from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.worker_builder import get_client_cls, \
    get_server_cls
from federatedscope.core.fed_runner import FedRunner

from fedhpobench.benchmarks.base_benchmark import BaseBenchmark
from fedhpobench.utils.util import disable_fs_logger
from fedhpobench.utils.cost_model import merge_cfg
|
|
|
|
|
|
class RawBenchmark(BaseBenchmark):
    """Benchmark that scores a configuration by running a real FL job.

    Every call to :meth:`objective_function` builds a FederatedScope
    ``FedRunner`` from a clone of ``self.cfg`` merged with the given
    configuration and fidelity, runs it, and reads the requested metric
    from the run results.
    """

    def __init__(self,
                 model,
                 dname,
                 algo,
                 rng=None,
                 cost_mode='estimated',
                 **kwargs):
        """Create the benchmark.

        Args:
            model: forwarded unchanged to ``BaseBenchmark``.
            dname: forwarded unchanged to ``BaseBenchmark``.
            algo: forwarded to ``BaseBenchmark``; the value ``'opt'`` also
                switches every FL run to FedOpt (see ``_run_fl``).
            rng: forwarded unchanged to ``BaseBenchmark``.
            cost_mode: forwarded unchanged to ``BaseBenchmark``.
            **kwargs: forwarded to ``BaseBenchmark``; must contain
                ``'device'``.

        Raises:
            KeyError: if ``'device'`` is missing from ``kwargs``.
        """
        # NOTE(review): cost_mode is passed positionally *before* rng, i.e.
        # this assumes BaseBenchmark.__init__(model, dname, algo, cost_mode,
        # rng, **kwargs) -- confirm against the base class.
        super().__init__(model, dname, algo, cost_mode, rng, **kwargs)
        # Copied into the config of every FL run (see ``_run_fl``).
        self.device = kwargs['device']

    def _run_fl(self, configuration, fidelity, key='val_avg_loss', seed=1):
        """Run one federated-learning job and return the metric ``key``.

        Args:
            configuration: hyperparameters merged into the run config.
            fidelity: fidelity settings merged into the run config.
            key: name of the metric to read from the run results.
            seed: seed passed to ``setup_seed`` before data is built.

        Returns:
            A one-element list holding the metric. It is read from
            ``results['server_global_eval']`` when that entry exists,
            otherwise from ``results['client_summarized_weighted_avg']``.
        """
        # Build this run's config from a fresh clone of the base config.
        run_cfg = self.cfg.clone()
        disable_fs_logger(run_cfg, True)
        setup_seed(seed)

        # Apply the trial's configuration/fidelity (merge_cfg presumably
        # returns an updated config -- its body is not visible here), then
        # let the data builder adjust it; get_data returns (data, cfg).
        trial_cfg = merge_cfg(run_cfg, configuration, fidelity)
        data, trial_cfg = get_data(trial_cfg.clone())
        run_cfg.merge_from_other_cfg(trial_cfg)
        run_cfg.device = self.device

        if self.algo == 'opt':
            # Switch the run to FedOpt.
            run_cfg.federate.share_local_model = False
            run_cfg.federate.online_aggr = False
            run_cfg.fedopt.use = True
            run_cfg.federate.method = 'FedOpt'

        run_cfg.freeze()
        runner = FedRunner(data=data,
                           server_class=get_server_cls(run_cfg),
                           client_class=get_client_cls(run_cfg),
                           config=run_cfg.clone())
        results = runner.run()
        # so that we could modify cfg in the next trial
        # NOTE(review): run_cfg is re-cloned from self.cfg on every call, so
        # this defrost looks redundant -- confirm before removing.
        run_cfg.defrost()

        if 'server_global_eval' in results:
            summary = results['server_global_eval']
        else:
            summary = results['client_summarized_weighted_avg']
        return [summary[key]]

    def objective_function(self,
                           configuration,
                           fidelity=None,
                           key='val_avg_loss',
                           seed=1,
                           **kwargs):
        """Evaluate ``configuration`` at ``fidelity`` with a real FL run.

        Args:
            configuration: hyperparameter configuration to evaluate; it is
                validated by ``self._check``.
            fidelity: fidelity setting; normalized by
                ``self._init_fidelity`` (so ``None`` is accepted).
            key: metric name to read from the FL results.
            seed: random seed for the FL run.
            **kwargs: forwarded to ``self._cost``.

        Returns:
            dict with ``'function_value'`` (the metric named by ``key``) and
            ``'cost'``. The cost is the value returned by ``self._cost`` when
            that value is truthy. Otherwise it is the wall-clock duration of
            the FL run, as a ``datetime.timedelta``.
        """
        fidelity = self._init_fidelity(fidelity)
        self._check(configuration, fidelity)

        # Monotonic clock: unlike differences of naive datetime.now() stamps,
        # it is not distorted by system-clock or DST adjustments.
        start = time.perf_counter()
        function_value = self._run_fl(configuration, fidelity, key, seed)[0]
        elapsed = time.perf_counter() - start

        # Query the cost model exactly once (it used to be evaluated twice).
        cost = self._cost(configuration, fidelity, **kwargs)
        if not cost:
            # TODO: use time from FS monitor
            # NOTE(review): this fallback is a timedelta while _cost
            # presumably returns a number -- confirm what callers expect.
            cost = datetime.timedelta(seconds=elapsed)

        return {'function_value': function_value, 'cost': cost}

    def get_configuration_space(self):
        """Return the configuration space: always empty for this benchmark."""
        return []

    def get_fidelity_space(self):
        """Return the fidelity space: always empty for this benchmark."""
        return []

    def get_meta_info(self):
        """Return a dict describing this benchmark instance.

        ``configuration_space`` and ``fidelity_space`` are read from instance
        attributes not set in this class -- presumably populated by
        ``BaseBenchmark``; verify there.
        """
        return {
            'model': self.model,
            'dname': self.dname,
            'configuration_space': self.configuration_space,
            'fidelity_space': self.fidelity_space
        }