49 lines
1.9 KiB
Python
49 lines
1.9 KiB
Python
# Copyright (c) Alibaba, Inc. and its affiliates.
|
|
import unittest
|
|
|
|
from federatedscope.core.auxiliaries.data_builder import get_data
|
|
from federatedscope.core.auxiliaries.utils import setup_seed
|
|
from federatedscope.core.auxiliaries.logging import update_logger
|
|
from federatedscope.core.configs.config import global_cfg
|
|
from federatedscope.core.auxiliaries.runner_builder import get_runner
|
|
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls
|
|
|
|
|
|
class EfficientSimulationTest(unittest.TestCase):
    """Check that two simulation implementations reach similar test loss.

    The experiment from ``scripts/example_configs/single_process.yaml`` is
    run twice from the same seed:

    1. with the configuration exactly as loaded (the "efficient" run), and
    2. with ``federate.share_local_model`` and ``federate.online_aggr``
       both forced to ``False`` (the "ordinary" run).

    The weighted-average client test loss of the two runs must differ by
    less than ``LOSS_TOLERANCE`` in either direction.
    """

    # Maximum allowed absolute difference between the two runs' test losses.
    LOSS_TOLERANCE = 0.1

    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _run_simulation(self, cfg):
        """Build data and a runner from ``cfg`` and return ``runner.run()``.

        Both ``get_data`` and ``get_runner`` receive clones of ``cfg``;
        the second value returned by ``get_data`` is ignored.
        """
        data, _ = get_data(cfg.clone())
        fed_runner = get_runner(data=data,
                                server_class=get_server_cls(cfg),
                                client_class=get_client_cls(cfg),
                                config=cfg.clone())
        return fed_runner.run()

    def test_toy_example_standalone_cmp_sim_impl(self):
        case_cfg = global_cfg.clone()
        case_cfg.merge_from_file('scripts/example_configs/single_process.yaml')

        # First run: configuration as loaded from the YAML file.
        setup_seed(case_cfg.seed)
        update_logger(case_cfg)
        efficient_test_results = self._run_simulation(case_cfg)

        # Second run: same seed, but without the shared local model and
        # without online aggregation.
        setup_seed(case_cfg.seed)
        case_cfg.merge_from_list([
            'federate.share_local_model', 'False', 'federate.online_aggr',
            'False'
        ])
        ordinary_test_results = self._run_simulation(case_cfg)

        efficient_loss = efficient_test_results[
            "client_summarized_weighted_avg"]['test_loss']
        ordinary_loss = ordinary_test_results[
            "client_summarized_weighted_avg"]['test_loss']
        # Use the absolute gap: a signed difference would let the check pass
        # whenever the first run's loss is lower, no matter by how much.
        self.assertLess(abs(efficient_loss - ordinary_loss),
                        self.LOSS_TOLERANCE)
|
|
|
|
|
|
if '__main__' == __name__:
    # Run the test cases defined in this module when executed as a script.
    unittest.main()
|