try:
    import tensorflow as tf
except ImportError:
    tf = None

import numpy as np

from federatedscope.core.trainers import Trainer
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.auxiliaries.utils import batch_iter
from federatedscope.core.trainers.context import CtxVar

class GeneralTFTrainer(Trainer):
    def train(self, target_data_split_name="train", hooks_set=None):
        hooks_set = self.hooks_in_train if hooks_set is None else hooks_set

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        # TODO: The return values should be more flexible? Now: sample_num,
        # model_para, results={k: v}
        return num_samples, self.ctx.model.state_dict(), self.ctx.eval_metrics
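
    # Note (editorial): the (num_samples, state_dict, eval_metrics) tuple
    # returned by ``train`` is what an FL client typically reports back to
    # the server -- local sample count, updated model parameters, and the
    # training metrics collected by the monitor.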

    def parse_data(self, data):
        """Populate "{}_data", "{}_loader" and "num_{}_data" for different
        modes
        """
        init_dict = dict()
        if isinstance(data, dict):
            for mode in ["train", "val", "test"]:
                init_dict["{}_data".format(mode)] = None
                init_dict["{}_loader".format(mode)] = None
                init_dict["num_{}_data".format(mode)] = 0
                if data.get(mode, None) is not None:
                    init_dict["{}_data".format(mode)] = data.get(mode)
                    init_dict["num_{}_data".format(mode)] = len(
                        data.get(mode))
        else:
            raise TypeError("Type of data should be dict.")
        return init_dict
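
    # Illustrative only: ``parse_data`` expects the client data organized as
    # a dict keyed by split name, e.g. (hypothetical split objects)
    #     data = {"train": train_set, "val": val_set, "test": test_set}
    # and produces {"train_data": train_set, "train_loader": None,
    # "num_train_data": len(train_set), ...} for each available split.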

    def register_default_hooks_train(self):
        self.register_hook_in_train(self._hook_on_fit_start_init,
                                    "on_fit_start")
        self.register_hook_in_train(self._hook_on_epoch_start,
                                    "on_epoch_start")
        self.register_hook_in_train(self._hook_on_batch_start_init,
                                    "on_batch_start")
        self.register_hook_in_train(self._hook_on_batch_forward,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_forward_regularizer,
                                    "on_batch_forward")
        self.register_hook_in_train(self._hook_on_batch_backward,
                                    "on_batch_backward")
        self.register_hook_in_train(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_train(self._hook_on_fit_end, "on_fit_end")

    def register_default_hooks_eval(self):
        # test/val
        self.register_hook_in_eval(self._hook_on_fit_start_init,
                                   "on_fit_start")
        self.register_hook_in_eval(self._hook_on_epoch_start,
                                   "on_epoch_start")
        self.register_hook_in_eval(self._hook_on_batch_start_init,
                                   "on_batch_start")
        self.register_hook_in_eval(self._hook_on_batch_forward,
                                   "on_batch_forward")
        self.register_hook_in_eval(self._hook_on_batch_end, "on_batch_end")
        self.register_hook_in_eval(self._hook_on_fit_end, "on_fit_end")
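
    # For reference (based on how ``_run_routine`` drives the hooks): within
    # one routine the hooks fire in the order
    #     on_fit_start -> (on_epoch_start -> (on_batch_start ->
    #     on_batch_forward -> on_batch_backward -> on_batch_end)*)* ->
    #     on_fit_end
    # so the registrations above reproduce a standard training/eval loop.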

    def _hook_on_fit_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.model``                       Move to ``ctx.device``
            ``ctx.loss_batch_total``            Initialize to 0
            ``ctx.loss_regular_total``          Initialize to 0
            ``ctx.num_samples``                 Initialize to 0
            ``ctx.ys_true``                     Initialize to ``[]``
            ``ctx.ys_prob``                     Initialize to ``[]``
            ==================================  ===========================
        """
        # prepare model
        ctx.model.to(ctx.device)

        # prepare statistics
        ctx.loss_batch_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.loss_regular_total = CtxVar(0., LIFECYCLE.ROUTINE)
        ctx.num_samples = CtxVar(0, LIFECYCLE.ROUTINE)
        ctx.ys_true = CtxVar([], LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar([], LIFECYCLE.ROUTINE)

    def _hook_on_epoch_start(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.{cur_split}_loader``          Initialize DataLoader
            ==================================  ===========================
        """
        # prepare dataloader
        setattr(ctx, "{}_loader".format(ctx.cur_split),
                batch_iter(ctx.get("{}_data".format(ctx.cur_split))))
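
    # Note (editorial): ``batch_iter`` from
    # federatedscope.core.auxiliaries.utils turns the raw split data into an
    # iterator of batches; the hooks below assume each batch is a dict with
    # 'x' and 'y' entries, as used in ``_hook_on_batch_forward``.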

    def _hook_on_batch_start_init(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.data_batch``                  Initialize batch data
            ==================================  ===========================
        """
        # prepare data batch
        try:
            ctx.data_batch = next(ctx.get("{}_loader".format(ctx.cur_split)))
        except StopIteration:
            raise StopIteration
    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.optimizer``                   Initialize optimizer
            ``ctx.batch_size``                  Calculate batch size
            ``ctx.loss_batch``                  Calculate batch loss
            ``ctx.model``                       Forward propagation
            ``ctx.y_true``                      Get y_true from batch
            ``ctx.y_prob``                      Forward propagation to get \
            ``y_prob``
            ==================================  ===========================
        """
        ctx.optimizer = ctx.model.optimizer

        ctx.batch_size = len(ctx.data_batch)

        with ctx.model.graph.as_default():
            with ctx.model.sess.as_default():
                feed_dict = {
                    ctx.model.input_x: ctx.data_batch['x'],
                    ctx.model.input_y: ctx.data_batch['y']
                }
                _, batch_loss, y_true, y_prob = ctx.model.sess.run(
                    [
                        ctx.model.train_op, ctx.model.losses,
                        ctx.model.input_y, ctx.model.out
                    ],
                    feed_dict=feed_dict)
                ctx.loss_batch = batch_loss
                ctx.y_true = CtxVar(y_true, LIFECYCLE.BATCH)
                ctx.y_prob = CtxVar(y_prob, LIFECYCLE.BATCH)
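
    # Note (editorial): the forward hook assumes a TF1-style model object
    # exposing ``graph``, ``sess``, ``input_x``/``input_y`` placeholders, a
    # ``train_op``, a ``losses`` tensor, an output tensor ``out`` and an
    # ``optimizer`` handle. Running ``train_op`` inside ``sess.run`` performs
    # both the forward and backward pass, which is presumably why
    # ``_hook_on_batch_backward`` below is a no-op.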

    def _hook_on_batch_forward_regularizer(self, ctx):
        pass

    def _hook_on_batch_backward(self, ctx):
        pass
    def _hook_on_batch_end(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.num_samples``                 Add ``ctx.batch_size``
            ``ctx.loss_batch_total``            Add batch loss
            ``ctx.loss_regular_total``          Add batch regular loss
            ``ctx.ys_true``                     Append ``ctx.y_true``
            ``ctx.ys_prob``                     Append ``ctx.y_prob``
            ==================================  ===========================
        """
        # TODO: the same with the torch_trainer
        # update statistics
        ctx.num_samples += ctx.batch_size
        ctx.loss_batch_total += ctx.loss_batch
        ctx.loss_regular_total += float(ctx.get("loss_regular", 0.))

        # cache labels for evaluation; ``y_true``/``y_prob`` already come
        # back from ``sess.run`` as numpy arrays, so no torch-style
        # ``.detach().cpu().numpy()`` conversion is needed here
        ctx.ys_true.append(ctx.y_true)
        ctx.ys_prob.append(ctx.y_prob)
    def _hook_on_fit_end(self, ctx):
        """
        Evaluate metrics.

        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.ys_true``                     Convert to ``numpy.array``
            ``ctx.ys_prob``                     Convert to ``numpy.array``
            ``ctx.monitor``                     Evaluate the results
            ``ctx.eval_metrics``                Get evaluated results from \
            ``ctx.monitor``
            ==================================  ===========================
        """
        ctx.ys_true = CtxVar(np.concatenate(ctx.ys_true), LIFECYCLE.ROUTINE)
        ctx.ys_prob = CtxVar(np.concatenate(ctx.ys_prob), LIFECYCLE.ROUTINE)
        results = self.ctx.monitor.eval(ctx)
        setattr(ctx, 'eval_metrics', results)
    def update(self, model_parameters, strict=False):
        # TODO: Fix TF in Distributed mode
        # 1) `pickle.loads((base64.b64decode(enc_tensor)))` to restore
        self.ctx.model.load_state_dict(model_parameters, strict=strict)
    def save_model(self, path, cur_round=-1):
        pass

    def load_model(self, path):
        pass

    def discharge_model(self):
        pass
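

# A minimal usage sketch (illustrative only; the constructor signature and
# the ``tf_model``/``client_data``/``cfg``/``monitor`` objects below are
# assumptions based on the generic ``Trainer`` interface, not verified
# against this repository):
#
#     trainer = GeneralTFTrainer(model=tf_model,
#                                data=client_data,
#                                device='cpu',
#                                config=cfg,
#                                monitor=monitor)
#     num_samples, model_para, metrics = trainer.train()
#     trainer.update(model_para)
#     eval_metrics = trainer.evaluate(target_data_split_name='test')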