198 lines
7.9 KiB
Python
198 lines
7.9 KiB
Python
import logging
|
|
|
|
from federatedscope.core.configs.config import CN
|
|
from federatedscope.register import register_config
|
|
|
|
# Module-level logger; `assert_data_cfg` below uses it to emit deprecation
# warnings for the legacy `cfg.data.*` dataloader options.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def extend_data_cfg(cfg):
    """Attach the data-related default options to ``cfg``.

    Creates the ``cfg.data``, ``cfg.dataloader`` and ``cfg.feat_engr`` nodes
    (plus the nested ``cfg.data.quadratic``, ``cfg.feat_engr.secure``,
    ``cfg.feat_engr.secure.encrypt`` and ``cfg.feat_engr.secure.dp`` nodes),
    fills in their default values, and finally registers
    ``assert_data_cfg`` as a check function on ``cfg``.

    Args:
        cfg: the configuration node to extend; it is modified in place.
    """
    # ------------------------------------------------------------------ #
    # Dataset related options
    # ------------------------------------------------------------------ #
    cfg.data = CN()
    data = cfg.data

    # NOTE: `root` and `type` are assigned a second time in the traffic-flow
    # section further down, so the two values set here are NOT the effective
    # defaults once this function returns.
    data.root = 'data'
    data.type = 'toy'
    # Whether to save the generated toy data
    data.save_data = False
    # Args for an external dataset, eg. [{'download': True}]
    data.args = []
    data.splitter = ''
    # Args for the splitter, eg. [{'alpha': 0.5}]
    data.splitter_args = []
    # Whether the server (workers with idx 0) holds all data; useful in the
    # global training/evaluation case
    data.server_holds_all = False
    data.subsample = 1.0
    # Train, valid, test splits
    data.splits = [0.8, 0.1, 0.1]
    # If True, the label distributions of the train/val/test sets over
    # clients will be kept consistent during splitting
    data.consistent_label_distribution = True
    data.cSBM_phi = [0.5, 0.5, 0.5]

    # Transform for x, eg.
    # [['ToTensor'], ['Normalize', {'mean': [0.9637], 'std': [0.1592]}]]
    data.transform = []
    # Transform for y, same format as `transform`
    data.target_transform = []
    # pre_transform for a `torch_geometric` dataset, same format as above
    data.pre_transform = []

    # Per-split variants; if not provided, `cfg.data.transform` is used for
    # all splits
    data.val_transform = []
    data.val_target_transform = []
    data.val_pre_transform = []
    data.test_transform = []
    data.test_target_transform = []
    data.test_pre_transform = []

    # Only takes effect when data.type = 'files'
    data.file_path = ''

    # ------------------------------------------------------------------ #
    # DataLoader related args
    # ------------------------------------------------------------------ #
    cfg.dataloader = CN()
    dataloader = cfg.dataloader

    dataloader.type = 'base'
    dataloader.batch_size = 64
    dataloader.shuffle = True
    dataloader.num_workers = 0
    dataloader.drop_last = False
    dataloader.pin_memory = False
    # GFL: graphsaint DataLoader
    dataloader.walk_length = 2
    dataloader.num_steps = 30
    # GFL: neighbor sampler DataLoader
    dataloader.sizes = [10, 5]
    # DP: -1 means per-rating privacy, otherwise per-user privacy
    dataloader.theta = -1

    # ------------------------------------------------------------------ #
    # Quadratic
    # ------------------------------------------------------------------ #
    data.quadratic = CN()
    quadratic = data.quadratic

    quadratic.dim = 1
    quadratic.min_curv = 0.02
    quadratic.max_curv = 12.5

    # ------------------------------------------------------------------ #
    # Hetero NLP tasks data (for ATC)
    # ------------------------------------------------------------------ #
    # Multiple datasets
    data.hetero_data_name = []
    # Each dataset can be split into several clients
    data.num_of_client_for_data = []
    data.max_seq_len = 384
    data.max_tgt_len = 128
    data.max_query_len = 128
    data.trunc_stride = 128
    data.cache_dir = ''
    data.hetero_synth_batch_size = 32
    data.hetero_synth_prim_weight = 0.5
    data.hetero_synth_feat_dim = 128
    data.num_contrast = 0
    data.is_debug = False

    # ------------------------------------------------------------------ #
    # Traffic flow data parameters. These are only default values; please
    # modify the specific parameters directly in the YAML files.
    # ------------------------------------------------------------------ #
    # NOTE: these two lines overwrite the 'data' / 'toy' values assigned at
    # the top of this function.
    data.root = 'data/trafficflow/PeMS04'
    data.type = 'trafficflow'
    data.num_nodes = 307
    data.lag = 12
    data.horizon = 12
    data.val_ratio = 0.2
    data.test_ratio = 0.2
    data.tod = False
    data.normalizer = 'std'
    data.column_wise = False
    data.default_graph = True
    data.add_time_in_day = True
    data.add_day_in_week = True
    data.steps_per_day = 288
    data.days_per_week = 7
    data.scaler = [0, 0]

    # ------------------------------------------------------------------ #
    # Feature engineering
    # ------------------------------------------------------------------ #
    cfg.feat_engr = CN()
    feat_engr = cfg.feat_engr

    feat_engr.type = ''
    feat_engr.scenario = 'hfl'
    # Used for binning
    feat_engr.num_bins = 5
    # Used for feature selection
    feat_engr.selec_threshold = 0.05
    feat_engr.selec_woe_binning = 'quantile'

    feat_engr.secure = CN()
    secure = feat_engr.secure

    secure.type = 'encrypt'
    secure.key_size = 3072

    secure.encrypt = CN()
    secure.encrypt.type = 'dummy'

    # Under dev
    secure.dp = CN()

    # --------------- outdated configs ---------------
    # TODO: delete this code block
    # Legacy mirrors of the `cfg.dataloader.*` options; `assert_data_cfg`
    # copies any non-default value over to `cfg.dataloader`.
    data.loader = ''
    data.batch_size = 64
    data.shuffle = True
    data.num_workers = 0
    data.drop_last = False
    data.walk_length = 2
    data.num_steps = 30
    data.sizes = [10, 5]

    # --------------- register corresponding check function ----------
    cfg.register_cfg_check_fun(assert_data_cfg)
|
|
|
|
|
|
def _require(condition, message):
    """Raise ``AssertionError(message)`` unless ``condition`` is truthy.

    Used in place of a bare ``assert`` statement so that the configuration
    checks are not stripped when Python runs with ``-O``, while callers that
    catch ``AssertionError`` keep working unchanged.
    """
    if not condition:
        raise AssertionError(message)


def assert_data_cfg(cfg):
    """Validate the data-related options of ``cfg`` and migrate legacy keys.

    Checks performed:
      * for the 'graphsaint-rw' dataloader, ``cfg.model.layer`` must equal
        ``cfg.dataloader.walk_length``;
      * for the 'neighbor' dataloader, ``cfg.model.layer`` must equal
        ``len(cfg.dataloader.sizes)``;
      * when ``cfg.data.type`` contains '@' (external data),
        ``cfg.federate.client_num`` must be positive and
        ``cfg.data.splitter`` must be non-empty;
      * when ``cfg.data.num_of_client_for_data`` is non-empty, its sum must
        equal ``cfg.federate.client_num``.

    Afterwards, every legacy ``cfg.data.*`` dataloader option that differs
    from its default is copied to the matching ``cfg.dataloader.*`` option
    and a deprecation warning is logged.

    Args:
        cfg: the configuration node to check; ``cfg.dataloader`` may be
            modified in place by the legacy migration.

    Raises:
        AssertionError: if any of the checks above fails.
    """
    # NOTE(review): these two checks read `cfg.dataloader.type` before the
    # legacy `cfg.data.loader` value is migrated below, so a dataloader type
    # given only through the legacy key is not validated here -- confirm
    # whether that is intended.
    if cfg.dataloader.type == 'graphsaint-rw':
        _require(cfg.model.layer == cfg.dataloader.walk_length,
                 'Sample size mismatch')
    if cfg.dataloader.type == 'neighbor':
        _require(cfg.model.layer == len(cfg.dataloader.sizes),
                 'Sample size mismatch')

    if '@' in cfg.data.type:
        _require(
            cfg.federate.client_num > 0,
            '`federate.client_num` should be greater than 0 when using '
            'external data')
        _require(
            cfg.data.splitter,
            '`data.splitter` should not be empty when using external data')

    # hetero NLP tasks data
    if len(cfg.data.num_of_client_for_data) > 0:
        _require(
            cfg.federate.client_num == sum(cfg.data.num_of_client_for_data),
            '`federate.client_num` should be equal to sum of '
            '`data.num_of_client_for_data`')

    # --------------------------------------------------------------------
    # For compatibility with older versions of FS
    # TODO: delete this code block
    legacy = cfg.data
    # (legacy key under `cfg.data`, new key under `cfg.dataloader`,
    #  whether the legacy value differs from its default)
    migrations = (
        ('loader', 'type', legacy.loader != ''),
        ('batch_size', 'batch_size', legacy.batch_size != 64),
        ('shuffle', 'shuffle', not legacy.shuffle),
        ('num_workers', 'num_workers', legacy.num_workers != 0),
        ('drop_last', 'drop_last', bool(legacy.drop_last)),
        ('walk_length', 'walk_length', legacy.walk_length != 2),
        ('num_steps', 'num_steps', legacy.num_steps != 30),
        ('sizes', 'sizes', legacy.sizes != [10, 5]),
    )
    for old_key, new_key, is_overridden in migrations:
        if not is_overridden:
            continue
        logger.warning(
            'config `cfg.data.{}` will be removed in the future, use '
            '`cfg.dataloader.{}` instead.'.format(old_key, new_key))
        setattr(cfg.dataloader, new_key, getattr(legacy, old_key))
    # --------------------------------------------------------------------
|
|
|
|
|
|
# Register `extend_data_cfg` under the name "data" through the project's
# `register_config` helper (imported at the top of this module).
register_config("data", extend_data_cfg)
|