import copy
from types import FunctionType
from typing import Type

import numpy as np

from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer

# Modifications:
# 1. Fix issue where the trainer cannot access the monitor. Line 68
# 2. Fix issue where deepcopy cannot copy items. Line 77
# (2024-10-6, czzhangheng)


class GeneralMultiModelTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 monitor=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
        """
        `GeneralMultiModelTrainer` supports train/eval via multiple
        internal models.

        Arguments:
            model_nums (int): how many internal models and optimizers
                will be held by the trainer
            models_interact_mode (str): how the models interact, can be
                "sequential" or "parallel"
            model: training model
            data: a dict that contains train/val/test data
            device: device to run on
            config: trainer-related configuration
            monitor: the monitor used to track training/evaluation metrics
            base_trainer: if given, the `GeneralMultiModelTrainer` will be
                initialized as a copy of `base_trainer`

        The sequential mode indicates interaction at the run_routine level
        [one model runs its whole routine, then does sth. for interaction,
        then the next model runs its whole routine]
            ... -> run_routine_model_i
                -> _switch_model_ctx
                -> (on_fit_end, _interact_to_other_models)
                -> run_routine_model_i+1
                -> ...

        The parallel mode indicates interaction at the point-in-time level
        [at a specific point-in-time, one model calls its hooks (including
        interaction), then the next model calls its hooks]
            ... -> (on_xxx_point, hook_xxx_model_i)
                -> (on_xxx_point, _interact_to_other_models)
                -> (on_xxx_point, _switch_model_ctx)
                -> (on_xxx_point, hook_xxx_model_i+1)
                -> ...
        """
        # support two initialization methods for the `GeneralMultiModelTrainer`
        # 1) from another trainer; or 2) standard init manner given (model,
        # data, device, config)
        if base_trainer is None:
            assert model is not None and \
                   data is not None and \
                   device is not None and \
                   config is not None, "when not copy construction, (model, " \
                                       "data, device, config) should not be " \
                                       "None"
            super(GeneralMultiModelTrainer,
                  self).__init__(model, data, device, config, monitor=monitor)
        else:
            assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
                   issubclass(type(base_trainer), GeneralMultiModelTrainer) \
                   or isinstance(base_trainer, GeneralTorchTrainer) or \
                   issubclass(type(base_trainer), GeneralTorchTrainer), \
                   "can only copy instances of `GeneralMultiModelTrainer` " \
                   "and its subclasses, or " \
                   "`GeneralTorchTrainer` and its subclasses"
            # self.__dict__ = copy.deepcopy(base_trainer.__dict__)
            # Copy attributes from base_trainer one by one, skipping
            # non-copyable objects
            for key, value in base_trainer.__dict__.items():
                try:
                    self.__dict__[key] = copy.deepcopy(value)
                except TypeError:
                    # if unable to deepcopy, fall back to a shared reference
                    self.__dict__[key] = value

        assert models_interact_mode in ["sequential", "parallel"], \
            f"Invalid models_interact_mode, should be `sequential` or " \
            f"`parallel`, but got {models_interact_mode}"
        self.models_interact_mode = models_interact_mode

        if int(model_nums) != model_nums or model_nums < 1:
            raise ValueError(
                f"model_nums should be integer and >= 1, got {model_nums}.")
        self.model_nums = model_nums

        self.ctx.cur_model_idx = 0  # used to mark the current model

        # different internal models can have different hook_sets
        self.hooks_in_train_multiple_models = [self.hooks_in_train]
        self.hooks_in_eval_multiple_models = [self.hooks_in_eval]
        self.init_multiple_models()
        self.init_multiple_model_hooks()
        assert len(self.ctx.models) == model_nums == \
            len(self.hooks_in_train_multiple_models) == \
            len(self.hooks_in_eval_multiple_models), \
            "After init, len(hooks_in_train_multiple_models), " \
            "len(hooks_in_eval_multiple_models), " \
            "len(ctx.models) and model_nums should be the same"
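
    # Illustrative sketch (comments only, not executed) of the two supported
    # construction paths. The names `cfg`, `my_model`, `my_data`, `monitor`
    # and `single_trainer` are hypothetical placeholders, not part of this
    # module:
    #
    #   # 1) standard init, given (model, data, device, config)
    #   trainer = GeneralMultiModelTrainer(
    #       model_nums=2, models_interact_mode="sequential",
    #       model=my_model, data=my_data, device="cuda:0",
    #       config=cfg, monitor=monitor)
    #
    #   # 2) copy construction from an existing GeneralTorchTrainer
    #   trainer = GeneralMultiModelTrainer(
    #       model_nums=2, models_interact_mode="parallel",
    #       base_trainer=single_trainer)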

    def init_multiple_models(self):
        """
        Init multiple models and optimizers: the default implementation
        is the copy-init manner.
        ========================= Extension =============================
        Users can override this function according to their own
        requirements.
        """
        additional_models = [
            copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
        ]
        self.ctx.models = [self.ctx.model] + additional_models

        self.ctx.optimizers = [
            get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer)
            for i in range(0, self.model_nums)
        ]
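
    # A minimal override sketch (hypothetical subclass and config key, not
    # part of this module), e.g. giving each internal model its own optimizer
    # settings taken from a per-model config list:
    #
    #   class MyMultiModelTrainer(GeneralMultiModelTrainer):
    #       def init_multiple_models(self):
    #           super().init_multiple_models()
    #           # assumes cfg.train.optimizer_per_model is a list of dicts
    #           self.ctx.optimizers = [
    #               get_optimizer(m, **opt_cfg) for m, opt_cfg in zip(
    #                   self.ctx.models, self.cfg.train.optimizer_per_model)
    #           ]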

    def register_multiple_model_hooks(self):
        """
        By default, all internal models adopt the same hook_set.

        Extension:
            Users can override this function to register customized hooks
            for different internal models.

        Note:
            - for sequential mode, users can append an interact_hook on
              begin/end triggers such as
              " -> (on_fit_end, _interact_to_other_models) -> "
            - for parallel mode, users can append an interact_hook on any
              trigger they want such as
              " -> (on_xxx_point, _interact_to_other_models) -> "
            - we must tell the running hooks which data_loader to
              call and which num_samples to count
        """
        self.hooks_in_train_multiple_models.extend([
            self.hooks_in_train_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
        self.hooks_in_eval_multiple_models.extend([
            self.hooks_in_eval_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
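
    # A minimal sketch (hypothetical hook name, not part of this module) of
    # attaching an interaction hook to the second internal model only, via
    # the per-model registration methods below:
    #
    #   def _hook_on_fit_end_exchange_info(ctx):
    #       ...  # e.g. stash ctx.model state for the other model to read
    #
    #   trainer.register_hook_in_train(
    #       new_hook=_hook_on_fit_end_exchange_info,
    #       trigger="on_fit_end",
    #       model_idx=1)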

    def init_multiple_model_hooks(self):
        self.register_multiple_model_hooks()
        if self.models_interact_mode == "sequential":
            # hooks_in_xxx is a list of dicts, hooks_in_xxx[i] stores the
            # specific set for the i-th internal model;
            # for each dict, the key indicates point-in-time and the value
            # indicates the specific hooks
            self.hooks_in_train = self.hooks_in_train_multiple_models
            self.hooks_in_eval = self.hooks_in_eval_multiple_models
        elif self.models_interact_mode == "parallel":
            # hooks_in_xxx is a dict whose key indicates point-in-time and
            # whose value indicates the specific hooks
            for trigger in list(self.hooks_in_train.keys()):
                self.hooks_in_train[trigger] = []
                self.hooks_in_eval[trigger] = []
                for model_idx in range(len(self.ctx.models)):
                    self.hooks_in_train[trigger].extend(
                        self.hooks_in_train_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_train[trigger].extend(
                        [self._switch_model_ctx])
                    self.hooks_in_eval[trigger].extend(
                        self.hooks_in_eval_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_eval[trigger].extend(
                        [self._switch_model_ctx])
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`, but got {self.models_interact_mode}")
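
    # For illustration (assuming two models and a single hypothetical hook
    # `h` registered on "on_batch_end"), parallel mode flattens the per-model
    # hooks into one list per trigger, interleaved with context switches:
    #
    #   hooks_in_train["on_batch_end"] ==
    #       [h_model_0, _switch_model_ctx, h_model_1, _switch_model_ctx]
    #
    # so each point-in-time runs every model's hooks in turn, and the final
    # _switch_model_ctx wraps ctx.cur_model_idx back to 0.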

    def register_hook_in_train(self,
                               new_hook,
                               trigger,
                               model_idx=0,
                               insert_pos=None,
                               base_hook=None,
                               insert_mode="before"):
        hooks_dict = self.hooks_in_train_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def register_hook_in_eval(self,
                              new_hook,
                              trigger,
                              model_idx=0,
                              insert_pos=None,
                              base_hook=None,
                              insert_mode="before"):
        hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def _switch_model_ctx(self, next_model_idx=None):
        if next_model_idx is None:
            next_model_idx = (self.ctx.cur_model_idx + 1) % len(
                self.ctx.models)
        self.ctx.cur_model_idx = next_model_idx
        self.ctx.model = self.ctx.models[next_model_idx]
        self.ctx.optimizer = self.ctx.optimizers[next_model_idx]

    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run the hooks_set and maintain the mode for multiple internal models

        Arguments:
            mode: running mode of client, chosen from train/val/test

        Note:
            Considering evaluation could be in ``hooks_set["on_epoch_end"]``,
            there could be two data loaders in self.ctx, so we must tell the
            running hooks which data_loader to call and which num_samples to
            count
        """
        num_samples_model = list()
        if self.models_interact_mode == "sequential":
            assert isinstance(hooks_set, list) and \
                isinstance(hooks_set[0], dict), \
                "When models_interact_mode=sequential, " \
                "hooks_set should be a list of dicts. " \
                "hooks_set[i] stores the specific set for the i-th internal " \
                "model. For each dict, the key indicates point-in-time and " \
                "the value indicates the specific hooks"
            for model_idx in range(len(self.ctx.models)):
                # switch different hooks & ctx for different internal models
                hooks_set_model_i = hooks_set[model_idx]
                self._switch_model_ctx(model_idx)
                # [Interaction at run_routine level]
                # one model runs its whole routine, then does sth. for
                # interaction, then the next model runs its whole routine
                # ... -> run_routine_model_i
                #     -> _switch_model_ctx
                #     -> (on_fit_end, _interact_to_other_models)
                #     -> run_routine_model_i+1
                #     -> ...
                num_samples = super()._run_routine(mode, hooks_set_model_i,
                                                   dataset_name)
                num_samples_model.append(num_samples)
        elif self.models_interact_mode == "parallel":
            assert isinstance(hooks_set, dict), \
                "When models_interact_mode=parallel, hooks_set should be a " \
                "dict whose key indicates point-in-time and value indicates " \
                "the specific hooks"
            # [Interaction at point-in-time level]
            # at a specific point-in-time, one model calls its hooks
            # (including interaction), then the next model calls its hooks
            # ... -> (on_xxx_point, hook_xxx_model_i)
            #     -> (on_xxx_point, _interact_to_other_models)
            #     -> (on_xxx_point, _switch_model_ctx)
            #     -> (on_xxx_point, hook_xxx_model_i+1)
            #     -> ...
            num_samples = super()._run_routine(mode, hooks_set, dataset_name)
            num_samples_model.append(num_samples)
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`, but got {self.models_interact_mode}")
        # For now, we return the average number of samples across the
        # different models
        return np.mean(num_samples_model)

    def get_model_para(self):
        """
        Return the parameters of the multiple internal models.

        :return: a single (filtered) state_dict when model_nums == 1,
            otherwise a list of state_dicts, one per internal model
        """
        trained_model_para = []
        for model_idx in range(self.model_nums):
            trained_model_para.append(
                self._param_filter(
                    self.ctx.models[model_idx].cpu().state_dict()))

        return trained_model_para[
            0] if self.model_nums == 1 else trained_model_para

    def update(self, model_parameters, strict=False):
        """
        Update the parameters of the multiple internal models.

        Arguments:
            model_parameters (list[dict]): multiple PyTorch Module objects'
                state_dicts (a single state_dict when model_nums == 1).
        """
        if self.model_nums == 1:
            super().update(model_parameters, strict=strict)
        else:
            assert isinstance(model_parameters, list) and isinstance(
                model_parameters[0], dict), \
                "model_parameters should be a list of multiple state_dicts"
            assert len(model_parameters) == self.model_nums, \
                f"model_parameters should have the same length as " \
                f"self.model_nums, " \
                f"but got {len(model_parameters)} and {self.model_nums} " \
                f"respectively"
            for model_idx in range(self.model_nums):
                self.ctx.models[model_idx].load_state_dict(
                    self._param_filter(model_parameters[model_idx]),
                    strict=strict)

    def train(self, target_data_split_name="train"):
        # return multiple model paras
        sample_size, _, results = super().train(target_data_split_name)

        return sample_size, self.get_model_para(), results
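

# Illustrative usage sketch (comments only, not executed). `trainer` is a
# hypothetical GeneralMultiModelTrainer instance constructed as in __init__:
#
#   sample_size, model_paras, results = trainer.train()
#   # with model_nums > 1, `model_paras` is a list of state_dicts and
#   # `sample_size` is the mean num_samples across the internal models
#   trainer.update(model_paras)
#   trainer.evaluate(target_data_split_name="test")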