41 lines
1.4 KiB
Python
41 lines
1.4 KiB
Python
import logging
|
|
|
|
from federatedscope.core.configs.config import CN
|
|
from federatedscope.register import register_config
|
|
|
|
|
|
def extend_aggregator_cfg(cfg):
    """Attach the ``aggregator`` option group, with defaults, to ``cfg``.

    Creates ``cfg.aggregator`` as a fresh ``CN`` node, fills in the default
    value of every aggregator-related option, and registers
    ``assert_aggregator_cfg`` as a check function on ``cfg``.

    fs supports the robust aggregation rules 'krum', 'median',
    'trimmedmean', 'bulyan' and 'normbounding'; for use cases, refer to
    tests/test_robust_aggregators.py.

    Args:
        cfg: The configuration node to extend; it is modified in place.
    """
    cfg.aggregator = CN()
    # Read the node back from ``cfg`` so every default below lands on the
    # exact object that ``cfg`` holds.
    agg_cfg = cfg.aggregator

    # Robust aggregation options.
    agg_cfg.robust_rule = 'fedavg'
    agg_cfg.byzantine_node_num = 0
    # Created with ``new_allowed=True``; presumably so rule-specific keys can
    # be added later -- confirm against the CN implementation.
    agg_cfg.BFT_args = CN(new_allowed=True)

    # For ATC method
    agg_cfg.num_agg_groups = 1
    agg_cfg.num_agg_topk = []
    agg_cfg.inside_weight = 1.0
    agg_cfg.outside_weight = 0.0

    # Register the corresponding check function.
    cfg.register_cfg_check_fun(assert_aggregator_cfg)
|
|
|
|
|
|
def assert_aggregator_cfg(cfg):
    """Warn when a robust aggregation rule is set without Byzantine nodes.

    The check is advisory only: it never raises and never modifies ``cfg``.
    A warning is logged when ``cfg.aggregator.robust_rule`` is one of
    'krum', 'normbounding', 'median', 'trimmedmean' or 'bulyan' while
    ``cfg.aggregator.byzantine_node_num`` equals 0.

    Args:
        cfg: Configuration object exposing ``cfg.aggregator.robust_rule`` and
            ``cfg.aggregator.byzantine_node_num``.

    Returns:
        None.
    """
    robust_rules = ('krum', 'normbounding', 'median', 'trimmedmean', 'bulyan')

    if cfg.aggregator.byzantine_node_num != 0:
        return

    if cfg.aggregator.robust_rule in robust_rules:
        # Use a module-named logger rather than ``logging.warning``: the
        # module-level helper logs on the root logger and implicitly calls
        # ``basicConfig()`` when root has no handlers.
        logging.getLogger(__name__).warning(
            'Although %s aggregation rule is '
            'applied, we found that cfg.aggregator.byzantine_node_num == 0',
            cfg.aggregator.robust_rule)
|
|
|
|
|
|
# Register ``extend_aggregator_cfg`` under the 'aggregator' key at import
# time. NOTE(review): presumably the global config builder later invokes every
# registered extension -- confirm in federatedscope.register.
register_config('aggregator', extend_aggregator_cfg)
|