from typing import Type

import numpy as np

import torch
from torch.nn.functional import softmax as f_softmax

from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer
from federatedscope.core.trainers.trainer_multi_model import \
    GeneralMultiModelTrainer

# Modifications:
# 1. Fix issue where the trainer cannot access the monitor. Line 31
#    (2024-10-6, czzhangheng)


class FedEMTrainer(GeneralMultiModelTrainer):
    """
    The FedEM implementation, "Federated Multi-Task Learning under a \
    Mixture of Distributions (NeurIPS 2021)" \
    based on Algorithm 1 in their paper and the official code:
    https://github.com/omarfoq/FedEM
    """
    def __init__(self,
                 model_nums,
                 models_interact_mode="sequential",
                 model=None,
                 data=None,
                 device=None,
                 config=None,
                 monitor=None,
                 base_trainer: Type[GeneralTorchTrainer] = None):
        super(FedEMTrainer,
              self).__init__(model_nums, models_interact_mode, model, data,
                             device, config, monitor=monitor)
        device = self.ctx.device

        # --------------- attribute-level modifications ----------------------
        # used to mix the internal models
        self.weights_internal_models = (torch.ones(self.model_nums) /
                                        self.model_nums).to(device)
        self.weights_data_sample = (
            torch.ones(self.model_nums, self.ctx.num_train_batch) /
            self.model_nums).to(device)
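        # In the paper's notation, `weights_internal_models` roughly plays
        # the role of the mixture coefficients (pi), while
        # `weights_data_sample` holds the per-batch responsibilities;
        # both are initialized uniformly.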

        self.ctx.all_losses_model_batch = torch.zeros(
            self.model_nums, self.ctx.num_train_batch).to(device)
        self.ctx.cur_batch_idx = -1
        # `ctx[f"{cur_data}_y_prob_ensemble"] = 0` in
        # func `_hook_on_fit_end_ensemble_eval`
        # -> self.ctx.test_y_prob_ensemble = 0
        # -> self.ctx.train_y_prob_ensemble = 0
        # -> self.ctx.val_y_prob_ensemble = 0

        # ---------------- action-level modifications -----------------------
        # see register_multiple_model_hooks(),
        # which is called in the __init__ of `GeneralMultiModelTrainer`

    def register_multiple_model_hooks(self):
        """
        Customized multiple_model_hooks, which is called
        in the __init__ of `GeneralMultiModelTrainer`
        """
        # First register hooks for model 0
        # ---------------- train hooks -----------------------
        self.register_hook_in_train(
            new_hook=self._hook_on_fit_start_mixture_weights_update,
            trigger="on_fit_start",
            insert_pos=0)  # insert at the front
        self.register_hook_in_train(
            new_hook=self._hook_on_fit_start_flop_count,
            trigger="on_fit_start",
            insert_pos=1  # follow the mixture operation
        )
        self.register_hook_in_train(new_hook=self._hook_on_fit_end_flop_count,
                                    trigger="on_fit_end",
                                    insert_pos=-1)
        self.register_hook_in_train(
            new_hook=self._hook_on_batch_forward_weighted_loss,
            trigger="on_batch_forward",
            insert_pos=-1)
        self.register_hook_in_train(
            new_hook=self._hook_on_batch_start_track_batch_idx,
            trigger="on_batch_start",
            insert_pos=0)  # insert at the front
        # ---------------- eval hooks -----------------------
        self.register_hook_in_eval(
            new_hook=self._hook_on_batch_end_gather_loss,
            trigger="on_batch_end",
            insert_pos=0
        )  # insert at the front (gather the loss before it is cleaned)
        self.register_hook_in_eval(
            new_hook=self._hook_on_batch_start_track_batch_idx,
            trigger="on_batch_start",
            insert_pos=0)  # insert at the front
        # replace the original evaluation with the ensemble one
        self.replace_hook_in_eval(new_hook=self._hook_on_fit_end_ensemble_eval,
                                  target_trigger="on_fit_end",
                                  target_hook_name="_hook_on_fit_end")

        # Then, for the other models, register the same hooks as model 0,
        # since the hook implementations differentiate between models
        # via ctx.cur_model_idx
        self.hooks_in_train_multiple_models.extend([
            self.hooks_in_train_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])
        self.hooks_in_eval_multiple_models.extend([
            self.hooks_in_eval_multiple_models[0]
            for _ in range(1, self.model_nums)
        ])

    def _hook_on_batch_start_track_batch_idx(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.cur_batch_idx``              Track batch index
          ================================== ===========================
        """
        # for both train & eval
        ctx.cur_batch_idx = (self.ctx.cur_batch_idx +
                             1) % self.ctx.num_train_batch

    def _hook_on_batch_forward_weighted_loss(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.loss_batch``                 Multiply by \
                                             ``weights_internal_models``
          ================================== ===========================
        """
        # for only train
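        # scale the current internal model's batch loss by its mixture
        # weight, so that its parameter update is weighted accordingly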
        ctx.loss_batch *= self.weights_internal_models[ctx.cur_model_idx]

    def _hook_on_batch_end_gather_loss(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.all_losses_model_batch``     Gather loss
          ================================== ===========================
        """
        # for only eval
        # record the loss before it is cleaned, for the later
        # weights_data_sample update
        ctx.all_losses_model_batch[ctx.cur_model_idx][
            ctx.cur_batch_idx] = ctx.loss_batch.item()

    def _hook_on_fit_start_mixture_weights_update(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.mode``                       Evaluate
          ================================== ===========================
        """
        # for only train
        if ctx.cur_model_idx != 0:
            # only do the mixture_weights_update once, at cur_model_idx 0
            pass
        else:
            # gather the losses of every sample in the iterator for each
            # internal model, by calling `evaluate()`
            for model_idx in range(self.model_nums):
                self._switch_model_ctx(model_idx)
                self.evaluate(target_data_split_name="train")
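            # Mixture-weight update (cf. Algorithm 1 of FedEM): the
            # responsibility of model k for batch b is proportional to
            # pi[k] * exp(-loss[k, b]); it is computed below as a softmax
            # of (log(pi) - loss.T) over the model dimension, and pi is
            # then refreshed as the mean responsibility over batches.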
            self.weights_data_sample = f_softmax(
                (torch.log(self.weights_internal_models) -
                 ctx.all_losses_model_batch.T),
                dim=1).T
            self.weights_internal_models = self.weights_data_sample.mean(
                dim=1)

            # restore the model_ctx
            self._switch_model_ctx(0)

    def _hook_on_fit_start_flop_count(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.monitor``                    Count total_flops
          ================================== ===========================
        """
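        # The mixture-weight update above evaluates every internal model on
        # the local training set, so the counter grows by flops_per_sample
        # * model_nums * num_train_data (a rough per-sample estimate).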
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            self.model_nums * ctx.num_train_data

    def _hook_on_fit_end_flop_count(self, ctx):
        """
        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.monitor``                    Count total_flops
          ================================== ===========================
        """
        self.ctx.monitor.total_flops += self.ctx.monitor.flops_per_sample * \
            self.model_nums * ctx.num_train_data

    def _hook_on_fit_end_ensemble_eval(self, ctx):
        """
        Ensemble evaluation

        Note:
          The modified attributes and corresponding operations are shown
          below:

          ================================== ===========================
          Attribute                          Operation
          ================================== ===========================
          ``ctx.ys_prob_ensemble``           Ensemble ys_prob
          ``ctx.ys_true``                    Concatenate results
          ``ctx.ys_prob``                    Concatenate results
          ``ctx.eval_metrics``               Get evaluated results from \
                                             ``ctx.monitor``
          ================================== ===========================
        """
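        # This hook runs once per internal model during evaluation; the
        # ensemble prediction is accumulated below as the mixture-weighted
        # sum of each model's predicted probabilities.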
if ctx.get("ys_prob_ensemble", None) is None:
|
|
ctx.ys_prob_ensemble = CtxVar(0, LIFECYCLE.ROUTINE)
|
|
ctx.ys_prob_ensemble += np.concatenate(
|
|
ctx.ys_prob) * self.weights_internal_models[
|
|
ctx.cur_model_idx].item()
|
|
|
|
# do metrics calculation after the last internal model evaluation done
|
|
if ctx.cur_model_idx == self.model_nums - 1:
|
|
ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true),
|
|
LIFECYCLE.ROUTINE)
|
|
ctx.ys_prob = ctx.ys_prob_ensemble
|
|
ctx.eval_metrics = self.ctx.monitor.eval(ctx)
|
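

# ---------------------------------------------------------------------------
# Minimal sketch of the mixture-weight update used above, runnable in
# isolation (illustrative only; FedEMTrainer itself is normally constructed
# and driven by FederatedScope's federated runner, and the sizes below are
# made-up examples).
if __name__ == "__main__":
    model_nums, num_train_batch = 3, 5  # hypothetical sizes
    weights_internal_models = torch.ones(model_nums) / model_nums
    # stand-in for `ctx.all_losses_model_batch` gathered during evaluation
    all_losses_model_batch = torch.rand(model_nums, num_train_batch)

    # E-step: per-batch responsibilities ~ softmax(log(pi) - loss)
    weights_data_sample = f_softmax(
        (torch.log(weights_internal_models) - all_losses_model_batch.T),
        dim=1).T
    # M-step: new mixture weights are the mean responsibilities
    weights_internal_models = weights_data_sample.mean(dim=1)
    print("updated mixture weights:", weights_internal_models)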