# FS-TFP/federatedscope/core/trainers/trainer_multi_model.py

import copy
from types import FunctionType
from typing import Type
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
import numpy as np


# Modifications:
# 1. Fix issue where the trainer cannot access the monitor. Line 68
# 2. Fix issue where deepcopy cannot copy items Line 77
# (2024-10-6, czzhangheng)
class GeneralMultiModelTrainer(GeneralTorchTrainer):
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 monitor=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
"""
`GeneralMultiModelTrainer` supports train/eval via multiple
internal models
Arguments:
model_nums (int): how many internal models and optimizers
will be held by the trainer
models_interact_mode (str): how the models interact, can be
"sequential" or "parallel".
model: training model
data: a dict contains train/val/test data
device: device to run
config: for trainer-related configuration
base_trainer: if given, the GeneralMultiModelTrainer init
will based on base_trainer copy
The sequential mode indicates the interaction at
run_routine level
[one model runs its whole routine, then do sth. for
interaction, then next model runs its whole routine]
... -> run_routine_model_i
-> _switch_model_ctx
-> (on_fit_end, _interact_to_other_models)
-> run_routine_model_i+1
-> ...
The parallel mode indicates the interaction
at point-in-time level
[At a specific point-in-time, one model call hooks (
including interaction), then next model call hooks]
... -> (on_xxx_point, hook_xxx_model_i)
-> (on_xxx_point, _interact_to_other_models)
-> (on_xxx_point, _switch_model_ctx)
-> (on_xxx_point, hook_xxx_model_i+1)
-> ...
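
        Example (an illustrative sketch, not executed by this module; the
        `model`, `data`, `device`, `cfg` and `monitor` objects are assumed
        to be built by the caller, as for `GeneralTorchTrainer`):

            # 1) standard init given (model, data, device, config)
            trainer = GeneralMultiModelTrainer(
                model_nums=2, models_interact_mode="sequential",
                model=model, data=data, device=device, config=cfg,
                monitor=monitor)

            # 2) copy-construction from an existing trainer
            trainer = GeneralMultiModelTrainer(
                model_nums=2, base_trainer=existing_torch_trainer)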
"""
        # support two initialization methods for the `GeneralMultiModelTrainer`
        # 1) from another trainer; or 2) standard init manner given (model,
        # data, device, config)
        if base_trainer is None:
            assert model is not None and \
                   data is not None and \
                   device is not None and \
                   config is not None, "when not copy construction, (model, " \
                                       "data, device, config) should not be " \
                                       "None"
            super(GeneralMultiModelTrainer,
                  self).__init__(model, data, device, config, monitor=monitor)
        else:
            assert isinstance(base_trainer, GeneralMultiModelTrainer) or \
                   issubclass(type(base_trainer), GeneralMultiModelTrainer) \
                   or isinstance(base_trainer, GeneralTorchTrainer) or \
                   issubclass(type(base_trainer), GeneralTorchTrainer), \
                   "can only copy instances of `GeneralMultiModelTrainer` " \
                   "and its subclasses, or " \
                   "`GeneralTorchTrainer` and its subclasses"
            # self.__dict__ = copy.deepcopy(base_trainer.__dict__)
            # copy the attributes of base_trainer one by one, skipping
            # objects that cannot be deep-copied
            for key, value in base_trainer.__dict__.items():
                try:
                    self.__dict__[key] = copy.deepcopy(value)
                except TypeError:
                    # fall back to a shallow reference when deepcopy fails
                    self.__dict__[key] = value

        assert models_interact_mode in ["sequential", "parallel"], \
            f"Invalid models_interact_mode, should be `sequential` or " \
            f"`parallel`, but got {models_interact_mode}"
        self.models_interact_mode = models_interact_mode

        if int(model_nums) != model_nums or model_nums < 1:
            raise ValueError(
                f"model_nums should be an integer and >= 1, got {model_nums}.")
        self.model_nums = model_nums

        self.ctx.cur_model_idx = 0  # used to mark the current model

        # different internal models can have different hook_sets
        self.hooks_in_train_multiple_models = [self.hooks_in_train]
        self.hooks_in_eval_multiple_models = [self.hooks_in_eval]
        self.init_multiple_models()
        self.init_multiple_model_hooks()
        assert len(self.ctx.models) == model_nums == \
            len(self.hooks_in_train_multiple_models) == \
            len(self.hooks_in_eval_multiple_models), \
            "After init, len(hooks_in_train_multiple_models), " \
            "len(hooks_in_eval_multiple_models), " \
            "len(ctx.models) and model_nums should be the same"

    def init_multiple_models(self):
        """
        Init multiple models and optimizers: the default implementation
        is the copy-init manner.

        ========================= Extension =============================
        Users can override this function according to their own
        requirements.
        """
        additional_models = [
            copy.deepcopy(self.ctx.model) for _ in range(self.model_nums - 1)
        ]
        self.ctx.models = [self.ctx.model] + additional_models

        self.ctx.optimizers = [
            get_optimizer(self.ctx.models[i], **self.cfg.train.optimizer)
            for i in range(0, self.model_nums)
        ]

    def register_multiple_model_hooks(self):
        """
        By default, all internal models adopt the same hook_set.

        ========================= Extension =============================
        Users can override this function to register customized hooks
        for different internal models.

        Note:
            - for sequential mode, users can append an interact_hook on
              begin/end triggers such as
              " -> (on_fit_end, _interact_to_other_models) -> "
            - for parallel mode, users can append an interact_hook on any
              trigger they want, such as
              " -> (on_xxx_point, _interact_to_other_models) -> "
            - we must tell the running hooks which data_loader to
              call and which num_samples to count
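
        Example (an illustrative sketch; `my_interact_hook` is a
        hypothetical user-defined hook, not part of this repo):

            trainer.register_hook_in_train(
                new_hook=my_interact_hook,
                trigger="on_fit_end",
                model_idx=1)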
"""
self.hooks_in_train_multiple_models.extend([
self.hooks_in_train_multiple_models[0]
for _ in range(1, self.model_nums)
])
self.hooks_in_eval_multiple_models.extend([
self.hooks_in_eval_multiple_models[0]
for _ in range(1, self.model_nums)
])

    def init_multiple_model_hooks(self):
        self.register_multiple_model_hooks()
        if self.models_interact_mode == "sequential":
            # hooks_in_xxx is a list of dict, hooks_in_xxx[i] stores
            # specific set for i-th internal model;
            # for each dict, the key indicates point-in-time and the value
            # indicates specific hook
            self.hooks_in_train = self.hooks_in_train_multiple_models
            self.hooks_in_eval = self.hooks_in_eval_multiple_models
        elif self.models_interact_mode == "parallel":
            # hooks_in_xxx is a dict whose key indicates point-in-time and
            # value indicates specific hook
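            # flatten the per-model hook lists: at every trigger, run the
            # i-th model's hooks and then `_switch_model_ctx`, so the ctx
            # rotates to the next internal model before its hooks run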
            for trigger in list(self.hooks_in_train.keys()):
                self.hooks_in_train[trigger] = []
                self.hooks_in_eval[trigger] = []
                for model_idx in range(len(self.ctx.models)):
                    self.hooks_in_train[trigger].extend(
                        self.hooks_in_train_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_train[trigger].extend(
                        [self._switch_model_ctx])
                    self.hooks_in_eval[trigger].extend(
                        self.hooks_in_eval_multiple_models[model_idx]
                        [trigger])
                    self.hooks_in_eval[trigger].extend(
                        [self._switch_model_ctx])
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`, but got {self.models_interact_mode}")

    def register_hook_in_train(self,
                               new_hook,
                               trigger,
                               model_idx=0,
                               insert_pos=None,
                               base_hook=None,
                               insert_mode="before"):
        hooks_dict = self.hooks_in_train_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def register_hook_in_eval(self,
                              new_hook,
                              trigger,
                              model_idx=0,
                              insert_pos=None,
                              base_hook=None,
                              insert_mode="before"):
        hooks_dict = self.hooks_in_eval_multiple_models[model_idx]
        self._register_hook(base_hook, hooks_dict, insert_mode, insert_pos,
                            new_hook, trigger)

    def _switch_model_ctx(self, next_model_idx=None):
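        # point ctx.model / ctx.optimizer at the chosen internal model;
        # without an explicit index, advance round-robin to the next model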
        if next_model_idx is None:
            next_model_idx = (self.ctx.cur_model_idx + 1) % len(
                self.ctx.models)
        self.ctx.cur_model_idx = next_model_idx
        self.ctx.model = self.ctx.models[next_model_idx]
        self.ctx.optimizer = self.ctx.optimizers[next_model_idx]

    def _run_routine(self, mode, hooks_set, dataset_name=None):
        """Run the hooks_set and maintain the mode for multiple internal models

        Arguments:
            mode: running mode of client, chosen from train/val/test

        Note:
            Considering evaluation could be in ``hooks_set["on_epoch_end"]``,
            there could be two data loaders in self.ctx, we must tell the
            running hooks which data_loader to call and which num_samples to
            count
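
            For example, in sequential mode ``hooks_set`` is expected to look
            like ``[{"on_fit_start": [...], ...}, {...}]`` (one dict per
            internal model), while in parallel mode it is a single dict such
            as ``{"on_fit_start": [...], ...}``.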
"""
num_samples_model = list()
if self.models_interact_mode == "sequential":
assert isinstance(hooks_set, list) and isinstance(hooks_set[0],
dict), \
"When models_interact_mode=sequential, " \
"hooks_set should be a list of dict" \
"hooks_set[i] stores specific set for i-th internal model." \
"For each dict, the key indicates point-in-time and the " \
"value indicates specific hook"
            for model_idx in range(len(self.ctx.models)):
                # switch different hooks & ctx for different internal models
                hooks_set_model_i = hooks_set[model_idx]
                self._switch_model_ctx(model_idx)
                # [Interaction at run_routine level]
                # one model runs its whole routine, then do sth. for
                # interaction, then next model runs its whole routine
                # ... -> run_routine_model_i
                #     -> _switch_model_ctx
                #     -> (on_fit_end, _interact_to_other_models)
                #     -> run_routine_model_i+1
                #     -> ...
                num_samples = super()._run_routine(mode, hooks_set_model_i,
                                                   dataset_name)
                num_samples_model.append(num_samples)
        elif self.models_interact_mode == "parallel":
            assert isinstance(hooks_set, dict), \
                "When models_interact_mode=parallel, hooks_set should be a " \
                "dict whose key indicates point-in-time and value indicates " \
                "specific hook"
            # [Interaction at point-in-time level]
            # at a specific point-in-time, one model calls its hooks
            # (including interaction), then the next model calls its hooks
            # ... -> (on_xxx_point, hook_xxx_model_i)
            #     -> (on_xxx_point, _interact_to_other_models)
            #     -> (on_xxx_point, _switch_model_ctx)
            #     -> (on_xxx_point, hook_xxx_model_i+1)
            #     -> ...
            num_samples = super()._run_routine(mode, hooks_set, dataset_name)
            num_samples_model.append(num_samples)
        else:
            raise RuntimeError(
                f"Invalid models_interact_mode, should be `sequential` or "
                f"`parallel`, but got {self.models_interact_mode}")

        # For now, we return the average number of samples over the
        # different models
        return np.mean(num_samples_model)

    def get_model_para(self):
        """
        Return multiple model parameters.

        :return: a single state_dict when model_nums == 1, otherwise a list
            of state_dicts, one per internal model
        """
        trained_model_para = []
        for model_idx in range(self.model_nums):
            trained_model_para.append(
                self._param_filter(
                    self.ctx.models[model_idx].cpu().state_dict()))

        return trained_model_para[
            0] if self.model_nums == 1 else trained_model_para

    def update(self, model_parameters, strict=False):
        # update multiple model paras
        """
        Arguments:
            model_parameters (list[dict]): multiple PyTorch Module objects'
                state_dict.
        """
        if self.model_nums == 1:
            super().update(model_parameters, strict=strict)
        else:
            assert isinstance(model_parameters, list) and isinstance(
                model_parameters[0], dict), \
                "model_parameters should be a list of multiple state_dict"
            assert len(model_parameters) == self.model_nums, \
                f"model_parameters should have the same length as " \
                f"self.model_nums, " \
                f"but got {len(model_parameters)} and {self.model_nums} " \
                f"respectively"
            for model_idx in range(self.model_nums):
                self.ctx.models[model_idx].load_state_dict(
                    self._param_filter(model_parameters[model_idx]),
                    strict=strict)

    def train(self, target_data_split_name="train"):
        # return multiple model paras
        sample_size, _, results = super().train(target_data_split_name)

        return sample_size, self.get_model_para(), results