FS-TFP/federatedscope/core/optimizer.py

import copy
from typing import Dict, List


def wrap_regularized_optimizer(base_optimizer, regular_weight):
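    """
    Wrap an optimizer so that, before each step, the gradients are
    regularized by the difference between the current parameters and a
    stored comparison (reference) parameter group.

    Args:
        base_optimizer: the torch optimizer instance to be wrapped
        regular_weight: weight of the parameter-difference regularization

    Returns:
        a ParaRegularOptimizer instance wrapping base_optimizer
    """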
    base_optimizer_type = type(base_optimizer)
    internal_base_optimizer = copy.copy(
        base_optimizer)  # shallow copy to link the underlying model para

    class ParaRegularOptimizer(base_optimizer_type):
        """
        Regularization-based optimizer wrapper
        """
        def __init__(self, base_optimizer, regular_weight):
            # inherit all the attributes of base optimizer
            self.__dict__.update(base_optimizer.__dict__)

            # attributes used in the wrapper
            self.optimizer = base_optimizer  # internal torch optimizer
            self.param_groups = self.optimizer.param_groups  # link the para
            # of internal optimizer with the wrapper
            self.regular_weight = regular_weight
            self.compared_para_groups = None

        def set_compared_para_group(self, compared_para_dict: List[Dict]):
            if not (isinstance(compared_para_dict, list)
                    and isinstance(compared_para_dict[0], dict)
                    and 'params' in compared_para_dict[0]):
                raise ValueError(
                    "compared_para_dict should be a torch style para group, "
                    "i.e., list[dict], "
                    "in which the dict stores the para with key `params`")
            self.compared_para_groups = copy.deepcopy(compared_para_dict)

        def reset_compared_para_group(self, target=None):
            # by default, del stale compared_para to free memory
            self.compared_para_groups = target

        def regularize_by_para_diff(self):
            """
            before optim.step(), regularize the gradients based on para diff
            """
            for group, compared_group in zip(self.param_groups,
                                             self.compared_para_groups):
                for p, compared_weight in zip(group['params'],
                                              compared_group['params']):
                    if p.grad is not None:
                        if compared_weight.device != p.device:
                            # For Tensor, the to() is not in-place operation
                            compared_weight = compared_weight.to(p.device)
                        p.grad.data = p.grad.data + self.regular_weight * (
                            p.data - compared_weight.data)

        def step(self):
            self.regularize_by_para_diff()  # key action
            self.optimizer.step()

    return ParaRegularOptimizer(internal_base_optimizer, regular_weight)
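

if __name__ == "__main__":
    # Minimal usage sketch (illustration only, not part of the library API;
    # the toy model, lr and regular_weight below are arbitrary): wrap a torch
    # SGD optimizer so that each step also pulls the parameters towards a
    # frozen reference copy, as in proximal-term (e.g. FedProx/Ditto-style)
    # local updates.
    import torch

    model = torch.nn.Linear(10, 1)
    base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
    opt = wrap_regularized_optimizer(base_opt, regular_weight=0.01)

    # Freeze the current parameters as the regularization anchor
    # (set_compared_para_group deep-copies them internally).
    opt.set_compared_para_group(opt.param_groups)

    loss = model(torch.randn(4, 10)).pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()  # regular_weight * (param - anchor) is added to grads first

    # Drop the stale anchor to free memory once it is no longer needed.
    opt.reset_compared_para_group()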