FS-TFP/tests/test_global_train_lr.py

# Copyright (c) Alibaba, Inc. and its affiliates.
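"""Smoke test for global training of a logistic-regression model on the
built-in toy dataset, run in FederatedScope's standalone mode."""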
import unittest
from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, \
    get_client_cls


class ToyLRTest(unittest.TestCase):
    def setUp(self):
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))
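
    # Build a standalone-mode config: 5 clients on the synthetic 'toy'
    # dataset, an 'lr' model trained for 30 rounds with the generic trainer,
    # and federate.method = "global" to exercise the global-training path.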
    def set_config_standalone(self, cfg, make_global_eval=False):
        import torch
        cfg.use_gpu = torch.cuda.is_available()
        cfg.federate.mode = 'standalone'
        cfg.federate.total_round_num = 30
        cfg.federate.make_global_eval = make_global_eval
        cfg.federate.client_num = 5
        cfg.eval.freq = 10
        cfg.data.type = 'toy'
        cfg.trainer.type = 'general'
        cfg.model.type = 'lr'
        cfg.early_stop.patience = 5
        cfg.federate.method = "global"
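
    # End-to-end check: build the toy data and the runner, run training, and
    # require the weighted-average test loss of the best results to be < 0.3.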
    def test_toy_example_standalone(self):
        init_cfg = global_cfg.clone()
        self.set_config_standalone(init_cfg)
        setup_seed(init_cfg.seed)
        update_logger(init_cfg, True)

        data, modified_config = get_data(init_cfg.clone())
        init_cfg.merge_from_other_cfg(modified_config)
        self.assertIsNotNone(data)
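
        # Resolve the server/client classes from the config, build the runner,
        # and run the full course; run() returns the summarized best results.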
        Fed_runner = get_runner(data=data,
                                server_class=get_server_cls(init_cfg),
                                client_class=get_client_cls(init_cfg),
                                config=init_cfg.clone())
        self.assertIsNotNone(Fed_runner)
        test_best_results = Fed_runner.run()
        print(test_best_results)
        self.assertLess(
            test_best_results["client_summarized_weighted_avg"]['test_loss'],
            0.3)
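

# The file can be executed directly, e.g.
#   python tests/test_global_train_lr.py
# from the FS-TFP root; unittest.main() below will pick up ToyLRTest.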
if __name__ == '__main__':
    unittest.main()