# Source: FS-TFP/federatedscope/core/feature/vfl/preprocess/standardization.py
# (repository viewer metadata: 41 lines, 1.1 KiB, Python)

import logging
import numpy as np
from federatedscope.core.feature.utils import merge_splits_feat
logger = logging.getLogger(__name__)
def wrap_standardization(worker):
    """
    This function is to perform z-norm/standardization for vfl tabular data.

    The per-feature mean and standard deviation are computed once over the
    features returned by ``merge_splits_feat(worker.data)`` and the same
    statistics are then applied to the ``'x'`` entry of every split that is
    present (``train_data``, ``val_data``, ``test_data``). The splits are
    modified in place.

    Features whose standard deviation is zero (constant columns) are only
    centred: a unit scale is used for them so the division cannot fill the
    column with NaN/inf values.

    Args:
        worker: ``federatedscope.core.workers.Worker`` to be wrapped

    Returns:
        Wrap worker z-norm/standardization data
    """
    logger.info('Start to execute standardization.')

    # Merge train & val & test
    merged_feat, _ = merge_splits_feat(worker.data)

    feat_mean = np.mean(merged_feat, axis=0)
    feat_std = np.std(merged_feat, axis=0)
    # A constant feature has std == 0; dividing by it would yield 0 / 0 = NaN
    # for the whole column. Fall back to a unit scale for such features.
    feat_scale = np.where(feat_std == 0, 1.0, feat_std)

    for split in ['train_data', 'val_data', 'test_data']:
        split_data = getattr(worker.data, split, None)
        # Skip splits that are missing, unset, or carry no feature matrix.
        if split_data is None or 'x' not in split_data:
            continue
        split_data['x'] = (split_data['x'] - feat_mean) / feat_scale

    # Features changed in place; let the worker re-initialise its
    # data-related variables (per the method name -- confirm in Worker).
    worker._init_data_related_var()
    return worker
def wrap_standardization_client(worker):
    """
    Client-named entry point for z-norm/standardization.

    Delegates to ``wrap_standardization`` without any extra processing.

    Args:
        worker: the worker to be wrapped

    Returns:
        The result of ``wrap_standardization(worker)``
    """
    wrapped_worker = wrap_standardization(worker)
    return wrapped_worker
def wrap_standardization_server(worker):
    """
    Server-named entry point for z-norm/standardization.

    Delegates to ``wrap_standardization`` without any extra processing.

    Args:
        worker: the worker to be wrapped

    Returns:
        The result of ``wrap_standardization(worker)``
    """
    wrapped_worker = wrap_standardization(worker)
    return wrapped_worker