41 lines
1.1 KiB
Python
41 lines
1.1 KiB
Python
import logging
|
|
import numpy as np
|
|
|
|
from federatedscope.core.feature.utils import merge_splits_feat
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def wrap_standardization(worker):
    """
    This function is to perform z-norm/standardization for vfl tabular data.

    Feature-wise mean and standard deviation are computed once over the
    merged features of all splits (as returned by ``merge_splits_feat``),
    and the same statistics are then applied to the ``'x'`` entry of each
    of ``train_data`` / ``val_data`` / ``test_data`` that is present and
    not ``None`` on ``worker.data``. The split dicts are updated in place.

    Features whose standard deviation is exactly zero (constant columns)
    are only centered, not scaled, so they become 0 instead of NaN.

    Args:
        worker: ``federatedscope.core.workers.Worker`` to be wrapped

    Returns:
        The same ``worker`` object, after its split features have been
        standardized and ``worker._init_data_related_var()`` has been
        called (presumably to refresh data-dependent state — confirm).
    """
    logger.info('Start to execute standardization.')

    # Merge train & val & test so that every split is scaled with the same
    # statistics.
    merged_feat, _ = merge_splits_feat(worker.data)

    feat_mean = np.mean(merged_feat, axis=0)
    feat_std = np.std(merged_feat, axis=0)
    # A constant feature column has std == 0; dividing by it would turn the
    # whole column into NaN (0 / 0). Substitute 1 so such columns are merely
    # centered (-> 0), matching the usual StandardScaler convention.
    feat_std = np.where(feat_std == 0, 1.0, feat_std)

    for split in ['train_data', 'val_data', 'test_data']:
        if hasattr(worker.data, split):
            split_data = getattr(worker.data, split)
            if split_data is not None and 'x' in split_data:
                split_data['x'] = (split_data['x'] - feat_mean) / feat_std
    worker._init_data_related_var()
    return worker
|
|
|
|
|
|
def wrap_standardization_client(worker):
    """
    Client-side hook for z-norm/standardization of vfl tabular data.

    Simply delegates to ``wrap_standardization``; no client-specific
    handling is performed here.

    Args:
        worker: ``federatedscope.core.workers.Worker`` to be wrapped

    Returns:
        Whatever ``wrap_standardization`` returns for ``worker``.
    """
    standardized_worker = wrap_standardization(worker)
    return standardized_worker
|
|
|
|
|
|
def wrap_standardization_server(worker):
    """
    Server-side hook for z-norm/standardization of vfl tabular data.

    Simply delegates to ``wrap_standardization``; no server-specific
    handling is performed here.

    Args:
        worker: ``federatedscope.core.workers.Worker`` to be wrapped

    Returns:
        Whatever ``wrap_standardization`` returns for ``worker``.
    """
    standardized_worker = wrap_standardization(worker)
    return standardized_worker
|