# Source: FS-TFP/tests/test_efficient_simulation.py (49 lines, 1.9 KiB, Python)

# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls
class EfficientSimulationTest(unittest.TestCase):
    """Compare two standalone runs of the single-process example config.

    The first run uses the config file as loaded; the second run disables
    ``federate.share_local_model`` and ``federate.online_aggr``. The test
    passes when the client-weighted test losses of both runs are close.
    """

    def setUp(self):
        """Print the name of the test case that is about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _run_course(self, cfg):
        """Build data and a runner from ``cfg`` and return the run results.

        Args:
            cfg: The config node for this run. It is cloned before being
                handed to the data builder and to the runner, so the
                caller's object is not passed along directly.

        Returns:
            Whatever ``run()`` of the object built by ``get_runner`` returns.
        """
        data, _ = get_data(cfg.clone())
        fed_runner = get_runner(data=data,
                                server_class=get_server_cls(cfg),
                                client_class=get_client_cls(cfg),
                                config=cfg.clone())
        return fed_runner.run()

    def test_toy_example_standalone_cmp_sim_impl(self):
        """Check that both simulation settings reach a similar test loss."""
        case_cfg = global_cfg.clone()
        case_cfg.merge_from_file('scripts/example_configs/single_process.yaml')

        # First run: settings exactly as given by the config file.
        setup_seed(case_cfg.seed)
        update_logger(case_cfg)
        efficient_test_results = self._run_course(case_cfg)

        # Second run: re-seed so both runs start from the same random state,
        # then switch off the two options under comparison.
        setup_seed(case_cfg.seed)
        case_cfg.merge_from_list([
            'federate.share_local_model', 'False',
            'federate.online_aggr', 'False',
        ])
        ordinary_test_results = self._run_course(case_cfg)

        summary_key = "client_summarized_weighted_avg"
        gap = (efficient_test_results[summary_key]['test_loss'] -
               ordinary_test_results[summary_key]['test_loss'])
        # Compare the magnitude of the difference: a signed check would let
        # an arbitrarily large negative gap pass unnoticed.
        self.assertLess(abs(gap), 0.1)
# Allow running this test module directly as a script.
if __name__ == '__main__':
    unittest.main()