import copy
from types import FunctionType
from typing import Type

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

import numpy as np

# Modifications:
# 1. Fix issue where the trainer cannot access the monitor. Line 68
# 2. Fix issue where deepcopy cannot copy items. Line 77
# (2024-10-6, czzhangheng)


class GeneralMultiModelTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 monitor=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
        """
        `GeneralMultiModelTrainer` supports train/eval via multiple
        internal models.

        Arguments:
            model_nums (int): how many internal models and optimizers
                will be held by the trainer
            models_interact_mode (str): how the models interact, can be
                "sequential" or "parallel".
            model: training model
            data: a dict containing train/val/test data
            device: device to run on
            config: trainer-related configuration
            monitor: the monitor object passed through to the base trainer
            base_trainer: if given, the GeneralMultiModelTrainer is
                initialized as a copy of base_trainer

        The sequential mode indicates interaction at the
        run_routine level
        [one model runs its whole routine, then does sth. for
        interaction, then the next model runs its whole routine]
            ... -> run_routine_model_i
                -> _switch_model_ctx
                -> (on_fit_end, _interact_to_other_models)
                -> run_routine_model_i+1
                -> ...

        The parallel mode indicates interaction
        at the point-in-time level
        [at a specific point-in-time, one model calls its hooks
        (including interaction), then the next model calls its hooks]
            ... -> (on_xxx_point, hook_xxx_model_i)
                -> (on_xxx_point, _interact_to_other_models)
                -> (on_xxx_point, _switch_model_ctx)
                -> (on_xxx_point, hook_xxx_model_i+1)
                -> ...
        """
        # support two initialization methods for the `GeneralMultiModelTrainer`
        # 1) from another trainer; or 2) standard init manner given (model,
        # data, device, config)
        if base_trainer is None:
            assert model is not None and \
                   data is not None and \
                   device is not None and \
                   config is not None, "when not copy construction, (model, " \
                                       "data, device, config) should not be " \
                                       "None"
            super(GeneralMultiModelTrainer,
                  self).__init__(model, data, device, config, monitor=monitor)
        else:
            assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
                   issubclass(type(base_trainer), GeneralMultiModelTrainer) \
                   or isinstance(base_trainer, GeneralTorchTrainer) or \
                   issubclass(type(base_trainer), GeneralTorchTrainer), \
                   "can only copy instances of `GeneralMultiModelTrainer` " \
                   "and its subclasses, or " \
                   "`GeneralTorchTrainer` and its subclasses"
            # self.__dict__ = copy.deepcopy(base_trainer.__dict__)
            # copy the attributes of base_trainer one by one, skipping
            # objects that cannot be deep-copied
            for key, value in base_trainer.__dict__.items():
                try:
                    self.__dict__[key] = copy.deepcopy(value)
                except TypeError:
                    # fall back to sharing the original object by reference
                    # when it cannot be deep-copied
                    self.__dict__[key] = value

        assert models_interact_mode in ["sequential", "parallel"], \
            f"Invalid models_interact_mode, should be `sequential` or " \
            f"`parallel`, but got {models_interact_mode}"
        self.models_interact_mode = models_interact_mode

        if int(model_nums) != model_nums or model_nums < 1:
            raise ValueError(
                f"model_nums should be integer and >= 1, got {model_nums}.")
        self.model_nums = model_nums

        self.ctx.cur_model_idx = 0  # used to mark the current model

        # different internal models can have different hook_sets
        self.hooks_in_train_multiple_models = [self.hooks_in_train]
        self.hooks_in_eval_multiple_models = [self.hooks_in_eval]
        self.init_multiple_models()
        self.init_multiple_model_hooks()
        assert len(self.ctx.models) == model_nums == \
            len(self.hooks_in_train_multiple_models) == len(
                self.hooks_in_eval_multiple_models), \
            "After init, len(hooks_in_train_multiple_models), " \
            "len(hooks_in_eval_multiple_models), " \
            "len(ctx.models) and model_nums should be the same"

    def init_multiple_models(self):
        """
        init multiple models and optimizers: the default implementation
        is the copy-init manner;
        ========================= Extension =============================
        users can override this function according to their own
        requirements
        """

        additional_models = [
            copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
        ]
        self.ctx.models = [self.ctx.model] + additional_models

        self.ctx.optimizers = [
            get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer)
            for i in range(0, self.model_nums)
        ]

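    # Extension sketch (hypothetical subclass, not part of this file):
    # `init_multiple_models` can be overridden to hold heterogeneous models;
    # `build_my_model` below is a made-up helper.
    #
    #   class MyMultiModelTrainer(GeneralMultiModelTrainer):
    #       def init_multiple_models(self):
    #           self.ctx.models = [
    #               build_my_model(self.cfg, idx)
    #               for idx in range(self.model_nums)
    #           ]
    #           self.ctx.optimizers = [
    #               get_optimizer(model, **self.cfg.train.optimizer)
    #               for model in self.ctx.models
    #           ]
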
    def register_multiple_model_hooks(self):
        """
        By default, all internal models adopt the same hook_set.

        Extension
            Users can override this function to register customized hooks
            for different internal models.

        Note:
            - for sequential mode, users can append interact_hook on
              begin/end triggers such as
              " -> (on_fit_end, _interact_to_other_models) -> "
            - for parallel mode, users can append interact_hook on any
              trigger they want such as
              " -> (on_xxx_point, _interact_to_other_models) -> "
            - we must tell the running hooks which data_loader to
              call and which num_samples to count
        """

        self.hooks_in_train_multiple_models.extend([
            self.hooks_in_train_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
        self.hooks_in_eval_multiple_models.extend([
            self.hooks_in_eval_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])

    def init_multiple_model_hooks(self):
        self.register_multiple_model_hooks()
        if self.models_interact_mode == "sequential":
            # hooks_in_xxx is a list of dicts; hooks_in_xxx[i] stores the
            # specific set for the i-th internal model;
            # for each dict, the key indicates the point-in-time and the
            # value indicates the specific hooks
            self.hooks_in_train = self.hooks_in_train_multiple_models
            self.hooks_in_eval = self.hooks_in_eval_multiple_models
        elif self.models_interact_mode == "parallel":
            # hooks_in_xxx is a dict whose key indicates the point-in-time
            # and whose value indicates the specific hooks
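            # The resulting per-trigger hook list is laid out as
            #   [hooks(model_0)..., _switch_model_ctx,
            #    hooks(model_1)..., _switch_model_ctx, ...],
            # i.e. each model runs its hooks at the trigger and then the
            # context is switched to the next model.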
            for trigger in list(self.hooks_in_train.keys()):
                self.hooks_in_train[trigger] = []
                self.hooks_in_eval[trigger] = []
                for model_idx in range(len(self.ctx.models)):
                    self.hooks_in_train[trigger].extend(
                        self.hooks_in_train_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_train[trigger].extend(
                        [self._switch_model_ctx])
                    self.hooks_in_eval[trigger].extend(
                        self.hooks_in_eval_multiple_models[model_idx][trigger])
                    self.hooks_in_eval[trigger].extend(
                        [self._switch_model_ctx])
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`,"
                f" but got {self.models_interact_mode}")

    def register_hook_in_train(self,
                               new_hook,
                               trigger,
                               model_idx=0,
                               insert_pos=None,
                               base_hook=None,
                               insert_mode="before"):
        hooks_dict = self.hooks_in_train_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def register_hook_in_eval(self,
                              new_hook,
                              trigger,
                              model_idx=0,
                              insert_pos=None,
                              base_hook=None,
                              insert_mode="before"):
        hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

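    # Usage sketch (hypothetical hook, not defined in this file): to let
    # models exchange information, a user could register an interaction
    # hook on the fit-end trigger of a chosen internal model, e.g.
    #
    #   def _interact_to_other_models(ctx):
    #       ...  # read/modify ctx.models to exchange information
    #
    #   trainer.register_hook_in_train(new_hook=_interact_to_other_models,
    #                                  trigger="on_fit_end",
    #                                  model_idx=0)
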
    def _switch_model_ctx(self, next_model_idx=None):
        if next_model_idx is None:
            next_model_idx = (self.ctx.cur_model_idx + 1) % len(
                self.ctx.models)
        self.ctx.cur_model_idx = next_model_idx
        self.ctx.model = self.ctx.models[next_model_idx]
        self.ctx.optimizer = self.ctx.optimizers[next_model_idx]

    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run the hooks_set and maintain the mode for multiple internal
        models.

        Arguments:
            mode: running mode of client, chosen from train/val/test

        Note:
            Considering evaluation could be in ``hooks_set["on_epoch_end"]``,
            there could be two data loaders in self.ctx, so we must tell
            the running hooks which data_loader to call and which
            num_samples to count
        """
        num_samples_model = list()
        if self.models_interact_mode == "sequential":
            assert isinstance(hooks_set, list) and isinstance(hooks_set[0],
                                                              dict), \
                "When models_interact_mode=sequential, " \
                "hooks_set should be a list of dicts; " \
                "hooks_set[i] stores the specific set for the i-th internal " \
                "model. For each dict, the key indicates the point-in-time " \
                "and the value indicates the specific hooks"
            for model_idx in range(len(self.ctx.models)):
                # switch different hooks & ctx for different internal models
                hooks_set_model_i = hooks_set[model_idx]
                self._switch_model_ctx(model_idx)
                # [Interaction at run_routine level]
                # one model runs its whole routine, then does sth. for
                # interaction, then the next model runs its whole routine
                # ... -> run_routine_model_i
                #     -> _switch_model_ctx
                #     -> (on_fit_end, _interact_to_other_models)
                #     -> run_routine_model_i+1
                #     -> ...
                num_samples = super()._run_routine(mode, hooks_set_model_i,
                                                   dataset_name)
                num_samples_model.append(num_samples)
        elif self.models_interact_mode == "parallel":
            assert isinstance(hooks_set, dict), \
                "When models_interact_mode=parallel, hooks_set should be a " \
                "dict whose key indicates the point-in-time and whose value " \
                "indicates the specific hooks"
            # [Interaction at point-in-time level]
            # at a specific point-in-time, one model calls its hooks
            # (including interaction), then the next model calls its hooks
            # ... -> (on_xxx_point, hook_xxx_model_i)
            #     -> (on_xxx_point, _interact_to_other_models)
            #     -> (on_xxx_point, _switch_model_ctx)
            #     -> (on_xxx_point, hook_xxx_model_i+1)
            #     -> ...
            num_samples = super()._run_routine(mode, hooks_set, dataset_name)
            num_samples_model.append(num_samples)
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`,"
                f" but got {self.models_interact_mode}")
        # For now, we return the average number of samples for the
        # different models
        return np.mean(num_samples_model)

    def get_model_para(self):
        """
        return multiple model parameters
        :return: a single filtered state_dict when model_nums == 1,
            otherwise a list with one filtered state_dict per internal model
        """
        trained_model_para = []
        for model_idx in range(self.model_nums):
            trained_model_para.append(
                self._param_filter(
                    self.ctx.models[model_idx].cpu().state_dict()))

        return trained_model_para[
            0] if self.model_nums == 1 else trained_model_para

    def update(self, model_parameters, strict=False):
        # update multiple model paras
        """
        Arguments:
            model_parameters (list[dict]): multiple PyTorch Module objects'
                state_dict.
        """
        if self.model_nums == 1:
            super().update(model_parameters, strict=strict)
        else:
            assert isinstance(model_parameters, list) and isinstance(
                model_parameters[0], dict), \
                "model_parameters should be a list of multiple state_dicts"
            assert len(model_parameters) == self.model_nums, \
                f"model_parameters should have the same length as " \
                f"self.model_nums, " \
                f"but got {len(model_parameters)} and {self.model_nums} " \
                f"respectively"
            for model_idx in range(self.model_nums):
                self.ctx.models[model_idx].load_state_dict(self._param_filter(
                    model_parameters[model_idx]),
                                                           strict=strict)

    def train(self, target_data_split_name="train"):
        # return multiple model paras
        sample_size, _, results = super().train(target_data_split_name)

        return sample_size, self.get_model_para(), results
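
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not executed here): the multi-model
# trainer is typically built on top of an existing single-model trainer.
# `model`, `data`, `device`, `cfg` and `monitor` are assumed to be provided
# by the surrounding FederatedScope setup.
#
#   base_trainer = GeneralTorchTrainer(model=model, data=data, device=device,
#                                      config=cfg, monitor=monitor)
#   multi_trainer = GeneralMultiModelTrainer(
#       model_nums=2,
#       models_interact_mode="sequential",
#       base_trainer=base_trainer)
#   sample_size, model_para, results = multi_trainer.train()
# ---------------------------------------------------------------------------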