FS-TFP/federatedscope/core/trainers/utils.py

85 lines
2.8 KiB
Python

import collections
import json
import math
def format_log_hooks(hooks_set):
    """
    Render registered hooks as an indented JSON string for logging.

    Arguments:
        hooks_set (dict | list[dict]): a mapping from a trigger name to an
            iterable of hook callables, or a list of such mappings.

    Returns:
        str: JSON dump of the hook names (indent=2), with every line after
        the first prefixed by a tab so it nests under a log message.
        Triggers whose hook iterable is empty are omitted.

    Raises:
        TypeError: if ``hooks_set`` is neither a list nor a dict.
    """
    def format_dict(target_dict):
        # A key only appears once a hook is appended, so triggers with no
        # registered hooks are left out of the output.
        print_dict = collections.defaultdict(list)
        for k, v in target_dict.items():
            for element in v:
                # Callables such as functools.partial have no __name__;
                # fall back to their repr instead of crashing.
                print_dict[k].append(
                    getattr(element, '__name__', repr(element)))
        return print_dict

    if isinstance(hooks_set, list):
        print_obj = [format_dict(_) for _ in hooks_set]
    elif isinstance(hooks_set, dict):
        print_obj = format_dict(hooks_set)
    else:
        # Without this branch ``print_obj`` would be unbound below.
        raise TypeError("hooks_set must be a dict or a list of dicts, got "
                        "{}".format(type(hooks_set).__name__))
    return json.dumps(print_obj, indent=2).replace('\n', '\n\t')
def filter_by_specified_keywords(param_name, filter_keywords):
    """
    Decide whether a parameter is kept, based on substrings of its name.

    Arguments:
        param_name (str): parameter name.
        filter_keywords (iterable of str): substrings that mark a parameter
            for removal.
    Returns:
        preserve (bool): ``False`` as soon as any keyword occurs in
        ``param_name``; ``True`` otherwise (including for no keywords).
    """
    matched = any(keyword in param_name for keyword in filter_keywords)
    return not matched
def move_to(obj, device):
    """
    Recursively move every tensor in a (possibly nested) container.

    Arguments:
        obj: a ``torch.Tensor``, or a dict / list / tuple nesting of them.
        device: target passed to ``Tensor.to`` (e.g. ``'cpu'``).

    Returns:
        An object of the same structure with each tensor moved. Dicts
        (including subclasses) come back as plain ``dict`` and lists as
        ``list``; named tuples keep their type, other tuples become ``tuple``.

    Raises:
        TypeError: if a leaf is not a tensor or one of the containers above.
    """
    import torch
    if torch.is_tensor(obj):
        return obj.to(device)
    if isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    if isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    if isinstance(obj, tuple):
        moved = [move_to(v, device) for v in obj]
        # Named tuples take positional fields rather than a single iterable.
        if hasattr(obj, '_fields'):
            return type(obj)(*moved)
        return tuple(moved)
    raise TypeError("Invalid type for move_to")
def get_random(dis_type, sample_shape, params, device):
    """
    Draw a sample from a distribution in ``torch.distributions``.

    Arguments:
        dis_type (str): class name in ``torch.distributions``,
            e.g. ``'Normal'``.
        sample_shape: shape passed to ``Distribution.sample``.
        params (dict): keyword arguments for the distribution constructor.
        device: target passed to ``Tensor.to`` for the returned sample.

    Returns:
        torch.Tensor: the drawn sample, moved to ``device``.

    Raises:
        NotImplementedError: if ``dis_type`` does not name a
            ``Distribution`` subclass in ``torch.distributions``.
    """
    import torch.distributions as distributions
    dis_class = getattr(distributions, dis_type, None)
    # A bare ``hasattr`` check also accepts submodules and helper functions
    # (e.g. ``kl_divergence``), which would then fail obscurely when called.
    if not (isinstance(dis_class, type)
            and issubclass(dis_class, distributions.Distribution)):
        raise NotImplementedError(
            "Distribution {} is not implemented, please refer to "
            "```torch.distributions``` "
            "(https://pytorch.org/docs/stable/distributions.html).".format(
                dis_type))
    generator = dis_class(**params)
    return generator.sample(sample_shape=sample_shape).to(device)
def calculate_batch_epoch_num(steps, batch_or_epoch, num_data, batch_size,
                              drop_last):
    """
    Derive the batch/epoch schedule for a local training round.

    Arguments:
        steps (int): number of epochs when ``batch_or_epoch == "epoch"``,
            otherwise number of batches.
        batch_or_epoch (str): ``"epoch"``; any other value is treated as
            batch mode.
        num_data (int): number of training samples.
        batch_size (int): samples per batch.
        drop_last (bool): whether a trailing incomplete batch is dropped.

    Returns:
        tuple: ``(num_batch_per_epoch, num_batch_last_epoch, num_epoch,
        num_total_batch)``. In batch mode the per-epoch count is capped at
        ``steps``; ``steps == 0`` in batch mode yields all zeros.

    Raises:
        RuntimeError: if the data yields no batch at all.
    """
    # Full batches, plus one partial batch when it is kept and non-empty.
    num_batch_per_epoch = num_data // batch_size + int(
        not drop_last and bool(num_data % batch_size))
    if num_batch_per_epoch == 0:
        raise RuntimeError(
            "The number of batch is 0, please check 'batch_size' or set "
            "'drop_last' as False")
    if batch_or_epoch == "epoch":
        num_epoch = steps
        num_batch_last_epoch = num_batch_per_epoch
        num_total_batch = steps * num_batch_per_epoch
    elif steps == 0:
        # Capping the epoch length at ``steps`` would make it 0, and the
        # ceil-division below would then divide by zero.
        num_batch_per_epoch = 0
        num_batch_last_epoch = 0
        num_epoch = 0
        num_total_batch = 0
    else:
        num_batch_per_epoch = min(num_batch_per_epoch, steps)
        num_epoch = math.ceil(steps / num_batch_per_epoch)
        # A remainder of 0 means the last epoch is a full one.
        num_batch_last_epoch = steps % num_batch_per_epoch or \
            num_batch_per_epoch
        num_total_batch = steps
    return num_batch_per_epoch, num_batch_last_epoch, num_epoch, \
        num_total_batch