import collections
import json
import math
def format_log_hooks(hooks_set):
    """Render a hook set (a dict, or a list of dicts) as an indented JSON
    string mapping each trigger name to the names of its hook functions."""
    def format_dict(target_dict):
        # Collect the __name__ of every hook registered under each trigger.
        print_dict = collections.defaultdict(list)
        for k, v in target_dict.items():
            for element in v:
                print_dict[k].append(element.__name__)
        return print_dict

    if isinstance(hooks_set, list):
        print_obj = [format_dict(_) for _ in hooks_set]
    elif isinstance(hooks_set, dict):
        print_obj = format_dict(hooks_set)
    else:
        # Fail early instead of hitting UnboundLocalError on `print_obj`.
        raise TypeError(
            "hooks_set must be a dict or a list of dicts, got {}".format(
                type(hooks_set)))
    return json.dumps(print_obj, indent=2).replace('\n', '\n\t')
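
# Usage sketch (hypothetical hook set, for illustration only): any callables
# grouped under trigger names work, since only their __name__ is reported.
#
#     >>> def on_fit_start():
#     ...     pass
#     >>> hooks = {'on_fit_start': [on_fit_start]}
#     >>> print(format_log_hooks(hooks))  # tab-indented JSON of hook names
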
def filter_by_specified_keywords(param_name, filter_keywords):
    """
    Arguments:
        param_name (str): parameter name.
        filter_keywords (iterable of str): keywords to filter by; the
            parameter is not preserved if its name contains any of them.
    Returns:
        preserve (bool): whether to preserve this parameter.
    """
    preserve = True
    for kw in filter_keywords:
        if kw in param_name:
            preserve = False
            break
    return preserve
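
# Usage sketch (hypothetical parameter names, for illustration): typically
# applied to model parameter names, e.g. to exclude batch-norm statistics.
#
#     >>> filter_by_specified_keywords('encoder.bn.running_mean', ['bn'])
#     False
#     >>> filter_by_specified_keywords('encoder.fc.weight', ['bn'])
#     True
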
def move_to(obj, device):
    """Recursively move tensors, possibly nested in dicts or lists, to
    the given device."""
    # Imported lazily so torch is only required when this helper is used.
    import torch
    if torch.is_tensor(obj):
        return obj.to(device)
    elif isinstance(obj, dict):
        return {k: move_to(v, device) for k, v in obj.items()}
    elif isinstance(obj, list):
        return [move_to(v, device) for v in obj]
    else:
        raise TypeError("Invalid type {} for move_to".format(type(obj)))
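
# Usage sketch (assumes torch is installed; 'cpu' avoids needing a GPU):
#
#     >>> import torch
#     >>> batch = {'x': torch.zeros(2, 3), 'y': [torch.ones(2)]}
#     >>> moved = move_to(batch, 'cpu')
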
def get_random(dis_type, sample_shape, params, device):
    """Draw a sample of `sample_shape` on `device` from the
    `torch.distributions` class named `dis_type`, constructed with
    `params` as keyword arguments."""
    import torch.distributions as distributions
    if not hasattr(distributions, dis_type):
        raise NotImplementedError(
            "Distribution {} is not implemented, please refer to "
            "`torch.distributions` "
            "(https://pytorch.org/docs/stable/distributions.html).".format(
                dis_type))
    generator = getattr(distributions, dis_type)(**params)
    return generator.sample(sample_shape=sample_shape).to(device)
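
# Usage sketch (assumes torch is installed): draw standard Gaussian noise,
# e.g. for parameter perturbation; 'Normal', loc and scale are the standard
# torch.distributions.Normal constructor arguments.
#
#     >>> noise = get_random('Normal', sample_shape=(3, 2),
#     ...                    params={'loc': 0.0, 'scale': 1.0}, device='cpu')
#     >>> noise.shape
#     torch.Size([3, 2])
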
def calculate_batch_epoch_num(steps, batch_or_epoch, num_data, batch_size,
                              drop_last):
    # A trailing incomplete batch counts as one more batch unless drop_last.
    num_batch_per_epoch = num_data // batch_size + int(
        not drop_last and bool(num_data % batch_size))
    if num_batch_per_epoch == 0:
        raise RuntimeError(
            "The number of batches is 0, please check 'batch_size' or set "
            "'drop_last' as False")
    elif batch_or_epoch == "epoch":
        # `steps` counts epochs: every epoch runs the full set of batches.
        num_epoch = steps
        num_batch_last_epoch = num_batch_per_epoch
        num_total_batch = steps * num_batch_per_epoch
    else:
        # `steps` counts batches: the last epoch may be partial.
        num_batch_per_epoch = min(num_batch_per_epoch, steps)
        num_epoch = math.ceil(steps / num_batch_per_epoch)
        num_batch_last_epoch = steps % num_batch_per_epoch or \
            num_batch_per_epoch
        num_total_batch = steps
    return num_batch_per_epoch, num_batch_last_epoch, num_epoch, \
        num_total_batch
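
# Worked example (values chosen for illustration): 10 samples with
# batch_size=3 and drop_last=False give 4 batches per epoch; running for
# steps=6 batches then spans 2 epochs, with 2 batches in the last epoch.
#
#     >>> calculate_batch_epoch_num(6, 'batch', 10, 3, False)
#     (4, 2, 2, 6)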