FS-TFP/federatedscope/core/configs/cfg_compression.py

39 lines
1.3 KiB
Python

import logging
from federatedscope.core.configs.config import CN
from federatedscope.register import register_config
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def extend_compression_cfg(cfg):
    """Attach the compression (quantization) options to ``cfg``.

    Creates the ``cfg.quantization`` node, fills in its default values and
    registers ``assert_compression_cfg`` as a check function on ``cfg``.

    Args:
        cfg: configuration node that is extended in place.
    """
    # ------------------------------------------------------------------ #
    # Options for compression, i.e. for communication efficiency
    # ------------------------------------------------------------------ #
    cfg.quantization = CN()
    quantization = cfg.quantization

    # Defaults
    quantization.method = 'none'  # supported: 'none', 'uniform'
    quantization.nbits = 8  # supported: 8, 16

    # Hook up the validator for the options declared above.
    cfg.register_cfg_check_fun(assert_compression_cfg)
def assert_compression_cfg(cfg):
    """Validate the quantization options stored on ``cfg``.

    An unsupported ``cfg.quantization.method`` is not fatal: a warning is
    logged and the method is reset to ``'none'`` (as the warning states).
    When the effective method is not ``'none'``, ``cfg.quantization.nbits``
    must be 8 or 16.

    Args:
        cfg: configuration node exposing ``quantization.method`` (str) and
            ``quantization.nbits``; ``quantization.method`` may be
            overwritten in place.

    Raises:
        ValueError: if quantization is enabled and ``nbits`` is not one of
            8 or 16.
    """
    method = cfg.quantization.method.lower()

    if method not in ('none', 'uniform'):
        logger.warning(
            f'Quantization method is expected to be one of ["none",'
            f'"uniform"], but got "{cfg.quantization.method}". So we '
            f'change it to "none"')
        # Actually perform the fallback announced in the warning; without
        # this the unsupported method would remain in the config.
        cfg.quantization.method = 'none'
        method = 'none'

    # The bit width only matters when quantization is enabled.
    if method != 'none' and cfg.quantization.nbits not in (8, 16):
        raise ValueError(f'The value of cfg.quantization.nbits is invalid, '
                         f'which is expected to be one on [8, 16] but got '
                         f'{cfg.quantization.nbits}.')
# Register the extension function under the name "compression" at import
# time. NOTE(review): presumably the global config builder later calls every
# registered function to populate its defaults - confirm in
# federatedscope.register.
register_config("compression", extend_compression_cfg)