88 lines
3.4 KiB
Python
88 lines
3.4 KiB
Python
import os
|
|
|
|
from fedhpobench.benchmarks.base_benchmark import BaseBenchmark
|
|
from fedhpobench.utils.surrogate_dataloader import build_surrogate_model, \
|
|
load_surrogate_model
|
|
|
|
|
|
class SurrogateBenchmark(BaseBenchmark):
    """Benchmark whose objective values are predicted by surrogate models.

    The surrogate models are loaded from ``<modeldir>/<model>/<dname>/<algo>``
    when that path exists; otherwise they are built from the tabular data
    under ``datadir``.
    """

    def __init__(self,
                 model,
                 dname,
                 algo,
                 modeldir='data/surrogate_model/',
                 datadir='data/tabular_data/',
                 rng=None,
                 cost_mode='estimated',
                 **kwargs):
        """Load (or build) the surrogate models, then initialise the base.

        :param model: model name; first path component below ``modeldir``
            and forwarded to the surrogate loader/builder.
        :param dname: dataset name; second path component.
        :param algo: algorithm name; third path component.
        :param modeldir: root directory of pre-built surrogate models, or
            ``None`` to always build from ``datadir``.
        :param datadir: root directory of the tabular data used to build the
            surrogate models when no pre-built ones are found.
        :param rng: forwarded to ``BaseBenchmark.__init__``.
        :param cost_mode: stored on the instance and forwarded to
            ``BaseBenchmark.__init__``.
        :param kwargs: forwarded to ``BaseBenchmark.__init__``.
        :raises AssertionError: if both ``datadir`` and ``modeldir`` are
            falsy.
        """
        self.model, self.dname, self.algo, self.cost_mode = model, dname, \
            algo, cost_mode
        assert datadir or modeldir, 'Please provide at least one of ' \
                                    '`datadir` and `modeldir`.'

        # The assertion above allows ``modeldir=None`` as long as ``datadir``
        # is given, so only look for a pre-built surrogate when a model
        # directory was actually supplied (os.path.join(None, ...) would
        # raise TypeError).
        model_path = None if modeldir is None else os.path.join(
            modeldir, model, dname, algo)
        if model_path is not None and os.path.exists(model_path):
            self.surrogate_models, self.meta_info, self.X, self.Y = \
                load_surrogate_model(modeldir, model, dname, algo)
        else:
            self.surrogate_models, self.meta_info, self.X, self.Y = \
                build_surrogate_model(datadir, model, dname, algo)

        # NOTE(review): the base initialiser runs last, presumably because it
        # relies on ``self.meta_info`` (e.g. through get_configuration_space)
        # -- confirm before reordering.
        super(SurrogateBenchmark, self).__init__(model, dname, algo, cost_mode,
                                                 rng, **kwargs)

    def _check(self, configuration, fidelity):
        """Assert that every supplied key is part of the benchmark's spaces.

        Only keys that are present are validated; missing keys are not
        reported here.

        :raises AssertionError: if a key of ``configuration`` is not in
            ``self.configuration_space`` or a key of ``fidelity`` is not in
            ``self.fidelity_space``.
        """
        for key in configuration:
            assert key in self.configuration_space, \
                'configuration invalid, check `configuration_space` for help.'
        for key in fidelity:
            assert key in self.fidelity_space, \
                'fidelity invalid, check `fidelity_space` for help.'

    def _make_prediction(self, configuration, fidelity, seed):
        """Predict the objective value with one of the surrogate models.

        The feature vector is the configuration values in sorted
        ``configuration_space`` key order, followed by the fidelity values in
        sorted ``fidelity_space`` key order.

        :returns: the first element of ``model.predict([x_in])``.
        """
        # NOTE(review): if ``self.rng`` is a numpy ``RandomState``,
        # ``randint(seed)`` draws from ``[0, seed)``: the default seed=1
        # always picks model 0 and seed <= 0 raises -- confirm intended.
        model = self.surrogate_models[self.rng.randint(seed) %
                                      len(self.surrogate_models)]
        x_in = [configuration[key] for key in sorted(self.configuration_space)]
        x_in += [fidelity[key] for key in sorted(self.fidelity_space)]
        return model.predict([x_in])[0]

    # noinspection DuplicatedCode
    def objective_function(self,
                           configuration,
                           fidelity=None,
                           seed=1,
                           **kwargs):
        """Evaluate ``configuration`` at ``fidelity`` using the surrogate.

        :param configuration: mapping whose keys are validated against
            ``configuration_space`` and whose values form the model input.
        :param fidelity: fidelity mapping or ``None``; normalised by
            ``self._init_fidelity`` (defined outside this class).
        :param seed: forwarded to ``_make_prediction``, where it is used to
            select the surrogate model.
        :param kwargs: forwarded to ``self._cost``.
        :returns: dict with ``'function_value'`` (the surrogate prediction)
            and ``'cost'`` (the result of ``self._cost``).
        """
        fidelity = self._init_fidelity(fidelity)
        self._check(configuration, fidelity)
        return {
            'function_value': self._make_prediction(configuration, fidelity,
                                                    seed),
            'cost': self._cost(configuration, fidelity, **kwargs)
        }

    def get_configuration_space(self):
        """Return the configuration-space names from ``meta_info``.

        The name ``'batch'`` from the surrogate metadata is exposed as
        ``'batch_size'``; every other name is returned unchanged, in order.
        """
        return ['batch_size' if name == 'batch' else name
                for name in self.meta_info['configuration_space']]

    def get_fidelity_space(self):
        """Return the fidelity space stored in the surrogate metadata."""
        return self.meta_info['fidelity_space']

    def get_meta_info(self):
        """Return a summary dict of model, dataset and search spaces."""
        return {
            'model': self.model,
            'dname': self.dname,
            'configuration_space': self.configuration_space,
            'fidelity_space': self.fidelity_space
        }
|