FS-TFP/federatedscope/core/configs/cfg_data.py

198 lines
7.9 KiB
Python

import logging
from federatedscope.core.configs.config import CN
from federatedscope.register import register_config
logger = logging.getLogger(__name__)
def extend_data_cfg(cfg):
# ---------------------------------------------------------------------- #
# Dataset related options
# ---------------------------------------------------------------------- #
cfg.data = CN()
cfg.data.root = 'data'
cfg.data.type = 'toy'
cfg.data.save_data = False # whether to save the generated toy data
cfg.data.args = [] # args for external dataset, eg. [{'download': True}]
cfg.data.splitter = ''
cfg.data.splitter_args = [] # args for splitter, eg. [{'alpha': 0.5}]
cfg.data.server_holds_all = False # whether the server (workers with
# idx 0) holds all data, useful in global training/evaluation case
cfg.data.subsample = 1.0
cfg.data.splits = [0.8, 0.1, 0.1] # Train, valid, test splits
cfg.data.consistent_label_distribution = True # If True, the label
# distributions of train/val/test set over clients will be kept
# consistent during splitting
cfg.data.cSBM_phi = [0.5, 0.5, 0.5]
cfg.data.transform = [
] # transform for x, eg. [['ToTensor'], ['Normalize', {'mean': [
# 0.9637], 'std': [0.1592]}]]
cfg.data.target_transform = [] # target_transform for y, use as above
cfg.data.pre_transform = [
] # pre_transform for `torch_geometric` dataset, use as above
# If not provided, use `cfg.data.transform` for all splits
cfg.data.val_transform = []
cfg.data.val_target_transform = []
cfg.data.val_pre_transform = []
cfg.data.test_transform = []
cfg.data.test_target_transform = []
cfg.data.test_pre_transform = []
# data.file_path takes effect when data.type = 'files'
cfg.data.file_path = ''
# DataLoader related args
cfg.dataloader = CN()
cfg.dataloader.type = 'base'
cfg.dataloader.batch_size = 64
cfg.dataloader.shuffle = True
cfg.dataloader.num_workers = 0
cfg.dataloader.drop_last = False
cfg.dataloader.pin_memory = False
# GFL: graphsaint DataLoader
cfg.dataloader.walk_length = 2
cfg.dataloader.num_steps = 30
# GFL: neighbor sampler DataLoader
cfg.dataloader.sizes = [10, 5]
# DP: -1 means per-rating privacy, otherwise per-user privacy
cfg.dataloader.theta = -1
# quadratic
cfg.data.quadratic = CN()
cfg.data.quadratic.dim = 1
cfg.data.quadratic.min_curv = 0.02
cfg.data.quadratic.max_curv = 12.5
# Hetero NLP tasks data (for ATC)
cfg.data.hetero_data_name = [] # multiple datasets
cfg.data.num_of_client_for_data = [
] # each dataset can be splited into several clients
cfg.data.max_seq_len = 384
cfg.data.max_tgt_len = 128
cfg.data.max_query_len = 128
cfg.data.trunc_stride = 128
cfg.data.cache_dir = ''
cfg.data.hetero_synth_batch_size = 32
cfg.data.hetero_synth_prim_weight = 0.5
cfg.data.hetero_synth_feat_dim = 128
cfg.data.num_contrast = 0
cfg.data.is_debug = False
# Traffic Flow data parameters, These are only default values.
# Please modify the specific parameters directly in the YAML files.
cfg.data.root = 'data/trafficflow/PeMS04'
cfg.data.type = 'trafficflow'
cfg.data.num_nodes = 307
cfg.data.lag = 12
cfg.data.horizon = 12
cfg.data.val_ratio = 0.2
cfg.data.test_ratio = 0.2
cfg.data.tod = False
cfg.data.normalizer = 'std'
cfg.data.column_wise = False
cfg.data.default_graph = True
cfg.data.add_time_in_day = True
cfg.data.add_day_in_week = True
cfg.data.steps_per_day = 288
cfg.data.days_per_week = 7
cfg.data.scaler = [0,0]
# feature engineering
cfg.feat_engr = CN()
cfg.feat_engr.type = ''
cfg.feat_engr.scenario = 'hfl'
cfg.feat_engr.num_bins = 5 # Used for binning
cfg.feat_engr.selec_threshold = 0.05 # Used for feature selection
cfg.feat_engr.selec_woe_binning = 'quantile'
cfg.feat_engr.secure = CN()
cfg.feat_engr.secure.type = 'encrypt'
cfg.feat_engr.secure.key_size = 3072
cfg.feat_engr.secure.encrypt = CN()
cfg.feat_engr.secure.encrypt.type = 'dummy'
cfg.feat_engr.secure.dp = CN() # under dev
# --------------- outdated configs ---------------
# TODO: delete this code block
cfg.data.loader = ''
cfg.data.batch_size = 64
cfg.data.shuffle = True
cfg.data.num_workers = 0
cfg.data.drop_last = False
cfg.data.walk_length = 2
cfg.data.num_steps = 30
cfg.data.sizes = [10, 5]
# --------------- register corresponding check function ----------
cfg.register_cfg_check_fun(assert_data_cfg)
def assert_data_cfg(cfg):
if cfg.dataloader.type == 'graphsaint-rw':
assert cfg.model.layer == cfg.dataloader.walk_length, 'Sample ' \
'size ' \
'mismatch'
if cfg.dataloader.type == 'neighbor':
assert cfg.model.layer == len(
cfg.dataloader.sizes), 'Sample size mismatch'
if '@' in cfg.data.type:
assert cfg.federate.client_num > 0, '`federate.client_num` should ' \
'be greater than 0 when using ' \
'external data'
assert cfg.data.splitter, '`data.splitter` should not be empty when ' \
'using external data'
# hetero NLP taks data
if len(cfg.data.num_of_client_for_data) > 0:
assert cfg.federate.client_num == \
sum(cfg.data.num_of_client_for_data), '`federate.client_num` ' \
'should be equal to sum '\
'of `data.num_of_client'\
'_for_data`'
# --------------------------------------------------------------------
# For compatibility with older versions of FS
# TODO: delete this code block
if cfg.data.loader != '':
logger.warning('config `cfg.data.loader` will be removed in the '
'future, use `cfg.dataloader.type` instead.')
cfg.dataloader.type = cfg.data.loader
if cfg.data.batch_size != 64:
logger.warning('config `cfg.data.batch_size` will be removed in the '
'future, use `cfg.dataloader.batch_size` instead.')
cfg.dataloader.batch_size = cfg.data.batch_size
if not cfg.data.shuffle:
logger.warning('config `cfg.data.shuffle` will be removed in the '
'future, use `cfg.dataloader.shuffle` instead.')
cfg.dataloader.shuffle = cfg.data.shuffle
if cfg.data.num_workers != 0:
logger.warning('config `cfg.data.num_workers` will be removed in the '
'future, use `cfg.dataloader.num_workers` instead.')
cfg.dataloader.num_workers = cfg.data.num_workers
if cfg.data.drop_last:
logger.warning('config `cfg.data.drop_last` will be removed in the '
'future, use `cfg.dataloader.drop_last` instead.')
cfg.dataloader.drop_last = cfg.data.drop_last
if cfg.data.walk_length != 2:
logger.warning('config `cfg.data.walk_length` will be removed in the '
'future, use `cfg.dataloader.walk_length` instead.')
cfg.dataloader.walk_length = cfg.data.walk_length
if cfg.data.num_steps != 30:
logger.warning('config `cfg.data.num_steps` will be removed in the '
'future, use `cfg.dataloader.num_steps` instead.')
cfg.dataloader.num_steps = cfg.data.num_steps
if cfg.data.sizes != [10, 5]:
logger.warning('config `cfg.data.sizes` will be removed in the '
'future, use `cfg.dataloader.sizes` instead.')
cfg.dataloader.sizes = cfg.data.sizes
# --------------------------------------------------------------------
register_config("data", extend_data_cfg)