import logging import numpy as np from federatedscope.core.feature.utils import merge_splits_feat logger = logging.getLogger(__name__) def wrap_standardization(worker): """ This function is to perform z-norm/standardization for vfl tabular data. Args: worker: ``federatedscope.core.workers.Worker`` to be wrapped Returns: Wrap worker z-norm/standardization data """ logger.info('Start to execute standardization.') # Merge train & val & test merged_feat, _ = merge_splits_feat(worker.data) feat_mean = np.mean(merged_feat, axis=0) feat_std = np.std(merged_feat, axis=0) for split in ['train_data', 'val_data', 'test_data']: if hasattr(worker.data, split): split_data = getattr(worker.data, split) if split_data is not None and 'x' in split_data: split_data['x'] = (split_data['x'] - feat_mean) / feat_std worker._init_data_related_var() return worker def wrap_standardization_client(worker): return wrap_standardization(worker) def wrap_standardization_server(worker): return wrap_standardization(worker)