84 lines
2.9 KiB
Python
84 lines
2.9 KiB
Python
from importlib import import_module
|
|
import federatedscope.register as register
|
|
|
|
|
|
def get_transform(config, package):
    """
    Build the dataset transforms described in ``config.data``.

    Three splits are handled. The default (train) split reads
    ``config.data['transform']``, ``['target_transform']`` and
    ``['pre_transform']``. The validation and test splits read the same
    keys prefixed with ``val_`` and ``test_``. Every non-empty entry is
    converted recursively:

    * A spec whose first element is a string, ``[name]`` or
      ``[name, kwargs]``, becomes one transform. Each builder in
      ``register.transform_dict`` is tried first, and the first non-None
      result wins. Otherwise ``getattr(<package>.transforms, name)(**kwargs)``
      is used.
    * Any other list is converted element by element. The results are
      wrapped in ``transforms.Compose`` (or ``transforms.Sequential``) when
      the package provides one, otherwise returned as a plain list.

    Args:
        config: ``CN`` from ``federatedscope/core/configs/config.py``
        package: name of the package whose ``transforms`` attribute is
            used, e.g. one of
            ``['torchvision', 'torch_geometric', 'torchtext', 'torchaudio']``

    Returns:
        Tuple ``(transform_funcs, val_transform_funcs, test_transform_funcs)``.
        Each element is a dict keyed by the non-empty entries among
        ``'transform'``, ``'target_transform'`` and ``'pre_transform'``.

        * If no ``val_*`` (resp. ``test_*``) entry is set, the train dict
          itself is returned for that split.
        * If nothing is configured at all, ``({}, {}, {})`` is returned
          and ``package`` is never imported.
    """
    suffixes = ('transform', 'target_transform', 'pre_transform')

    def collect(prefix):
        # Raw (unconverted) specs of one split, keyed by the un-prefixed
        # name; empty entries are skipped.
        return {
            name: config.data[prefix + name]
            for name in suffixes if config.data[prefix + name]
        }

    transform_funcs = collect('')
    val_transform_funcs = collect('val_')
    test_transform_funcs = collect('test_')

    # Every transform entry is empty: return empty dicts without importing
    # ``package`` at all.
    if not (transform_funcs or val_transform_funcs or test_transform_funcs):
        return {}, {}, {}

    transforms = getattr(import_module(package), 'transforms')

    def convert(trans):
        # Recursively convert a spec into a transform object.
        if trans and isinstance(trans[0], str):
            # Single transform: ``[name]`` or ``[name, kwargs]``.
            # ``trans`` belongs to the caller's config, so do not append a
            # default ``{}`` to it; substitute one locally instead.
            if len(trans) == 1:
                transform_type, transform_args = trans[0], {}
            else:
                transform_type, transform_args = trans
            # Registered builders get the first chance; the first non-None
            # result wins.
            for func in register.transform_dict.values():
                transform_func = func(transform_type, transform_args)
                if transform_func is not None:
                    return transform_func
            return getattr(transforms, transform_type)(**transform_args)

        # Otherwise ``trans`` is a list of specs: convert each and compose.
        transform = [convert(x) for x in trans]
        if hasattr(transforms, 'Compose'):
            return transforms.Compose(transform)
        if hasattr(transforms, 'Sequential'):
            return transforms.Sequential(transform)
        return transform

    def convert_all(specs):
        # Convert every stored spec of one split. Using the spec stored for
        # that split, rather than re-reading ``config.data`` with the
        # un-prefixed key, keeps val/test from picking up the train config.
        return {key: convert(spec) for key, spec in specs.items()}

    transform_funcs = convert_all(transform_funcs)
    # Splits without their own transforms fall back to the train ones.
    val_transform_funcs = (convert_all(val_transform_funcs)
                           if val_transform_funcs else transform_funcs)
    test_transform_funcs = (convert_all(test_transform_funcs)
                            if test_transform_funcs else transform_funcs)

    return transform_funcs, val_transform_funcs, test_transform_funcs
|