# Copyright (c) Alibaba, Inc. and its affiliates.
import unittest

from federatedscope.core.auxiliaries.data_builder import get_data
from federatedscope.core.auxiliaries.utils import setup_seed
from federatedscope.core.auxiliaries.logging import update_logger
from federatedscope.core.configs.config import global_cfg
from federatedscope.core.auxiliaries.runner_builder import get_runner
from federatedscope.core.auxiliaries.worker_builder import get_server_cls, \
    get_client_cls


class robust_aggr_AlgoTest(unittest.TestCase):
    """Run FedAvg on FEMNIST with each robust aggregation rule.

    Every test case loads the same baseline YAML, configures 50 clients,
    10 training rounds and 10 byzantine nodes (client ids 1..10) that use
    the ``gaussian_noise`` attack, selects one ``aggregator.robust_rule``
    and then checks the reported weighted-average test accuracy.
    """

    # Baseline configuration shared by every test case.
    _BASELINE_YAML = \
        'federatedscope/cv/baseline/fedavg_convnet2_on_femnist.yaml'

    def setUp(self):
        """Print the name of the test that is about to run."""
        print('Testing %s.%s' % (type(self).__name__, self._testMethodName))

    def _set_common_config(self, cfg, robust_rule):
        """Apply the settings shared by all robust-aggregation tests.

        Args:
            cfg: configuration node to modify in place.
            robust_rule: value assigned to ``cfg.aggregator.robust_rule``.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        backup_cfg = cfg.clone()

        cfg.merge_from_file(self._BASELINE_YAML)

        cfg.data.root = 'test_data/'
        cfg.federate.client_num = 50
        cfg.federate.total_round_num = 10

        cfg.aggregator.byzantine_node_num = 10
        cfg.aggregator.robust_rule = robust_rule

        cfg.attack.attack_method = 'gaussian_noise'
        # Attackers are the first `byzantine_node_num` clients; the ids
        # start at 1.
        cfg.attack.attacker_id = list(
            range(1, cfg.aggregator.byzantine_node_num + 1))

        return backup_cfg

    def set_config_multikrum(self, cfg):
        """Configure ``cfg`` for the 'krum' rule with ``krum_agg_num`` 30.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        backup_cfg = self._set_common_config(cfg, 'krum')
        cfg.aggregator.BFT_args.krum_agg_num = 30
        return backup_cfg

    def set_config_median(self, cfg):
        """Configure ``cfg`` for the 'median' rule.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        return self._set_common_config(cfg, 'median')

    def set_config_trimmedmean(self, cfg):
        """Configure ``cfg`` for 'trimmedmean' with excluded ratio 0.2.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        backup_cfg = self._set_common_config(cfg, 'trimmedmean')
        cfg.aggregator.BFT_args.trimmedmean_excluded_ratio = 0.2
        return backup_cfg

    def set_config_bulyan(self, cfg):
        """Configure ``cfg`` for the 'bulyan' rule.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        return self._set_common_config(cfg, 'bulyan')

    def set_config_normbounding(self, cfg):
        """Configure ``cfg`` for 'normbounding' with a norm bound of 5.

        Returns:
            A clone of ``cfg`` taken before any modification.
        """
        backup_cfg = self._set_common_config(cfg, 'normbounding')
        cfg.aggregator.BFT_args.normbounding_norm_bound = 5
        return backup_cfg

    def _run_and_check(self, set_config):
        """Build a runner from a fresh config, run it and check accuracy.

        Args:
            set_config: one of the ``set_config_*`` bound methods; it is
                called with a clone of ``global_cfg`` and must return the
                backup configuration.
        """
        init_cfg = global_cfg.clone()
        backup_cfg = set_config(init_cfg)
        setup_seed(init_cfg.seed)
        update_logger(init_cfg, True)

        # get_data returns a (possibly) modified config alongside the data;
        # fold those modifications back into the working config.
        data, modified_cfg = get_data(init_cfg.clone())
        init_cfg.merge_from_other_cfg(modified_cfg)
        self.assertIsNotNone(data)

        Fed_runner = get_runner(data=data,
                                server_class=get_server_cls(init_cfg),
                                client_class=get_client_cls(init_cfg),
                                config=init_cfg.clone())
        self.assertIsNotNone(Fed_runner)

        test_best_results = Fed_runner.run()
        print(test_best_results)

        # Restore before asserting so the merge happens even when the
        # accuracy check below fails.
        init_cfg.merge_from_other_cfg(backup_cfg)

        # NOTE(review): the check is an upper bound (accuracy < 0.7);
        # confirm that this is the intended direction and threshold.
        self.assertLess(
            test_best_results['client_summarized_weighted_avg']['test_acc'],
            0.7)

    def test_0_multikrum(self):
        """Run the federated course with the multi-krum configuration."""
        self._run_and_check(self.set_config_multikrum)

    def test_1_median(self):
        """Run the federated course with the median configuration."""
        self._run_and_check(self.set_config_median)

    def test_2_trimmedmean(self):
        """Run the federated course with the trimmed-mean configuration."""
        self._run_and_check(self.set_config_trimmedmean)

    def test_3_bulyan(self):
        """Run the federated course with the bulyan configuration."""
        self._run_and_check(self.set_config_bulyan)

    def test_4_normbounding(self):
        """Run the federated course with the norm-bounding configuration."""
        self._run_and_check(self.set_config_normbounding)


if __name__ == '__main__':
    unittest.main()