import numpy as np
import torch

from federatedscope.core.trainers.torch_trainer import \
    GeneralTorchTrainer as Trainer
from federatedscope.trafficflow.dataset.normalization import StandardScaler
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
from federatedscope.core.trainers.context import Context, CtxVar, lifecycle


def print_model_parameters(model):
    """Print a header line, then ``name: shape`` for each named parameter.

    Args:
        model: any object exposing ``named_parameters()`` that yields
            ``(name, tensor)`` pairs, e.g. a ``torch.nn.Module``.
    """
    print("Model parameters and their shapes:")
    for param_name, tensor in model.named_parameters():
        print(f"{param_name}: {tensor.shape}")


class TrafficflowTrainer(Trainer):
    """GeneralTorchTrainer variant for the traffic-flow task.

    Model outputs are passed through ``StandardScaler.inverse_transform``
    before the loss is computed. ``train`` also evaluates on the 'val' and
    'test' splits and reports the three average losses together.
    """

    def __init__(self, model, scaler, args, data, device, monitor):
        """Build the trainer and its output scaler.

        Args:
            model: the model to train.
            scaler: indexable; elements 0 and 1 are handed to
                ``StandardScaler`` (presumably mean and std -- confirm
                against the dataset loader).
            args: configuration, forwarded to the base trainer as its
                ``config`` positional argument.
            data: data forwarded to the base trainer.
            device: device forwarded to the base trainer.
            monitor: forwarded to the base trainer as ``monitor=``.
        """
        super().__init__(model, data, device, args, monitor=monitor)
        scale_a, scale_b = scaler[0], scaler[1]
        self.scaler = StandardScaler(scale_a, scale_b)

    def train(self, target_data_split_name="train", hooks_set=None):
        """Run one training routine, then evaluate on 'val' and 'test'.

        Args:
            target_data_split_name: split to train on (default "train").
            hooks_set: hooks to run; ``self.hooks_in_train`` is used when
                this is falsy.

        Returns:
            tuple: ``(num_samples, model_parameters, metrics)``. ``metrics``
            is a dict with keys ``'train_loss'``, ``'val_loss'`` and
            ``'test_loss'``. The same dict is also stored on
            ``self.ctx.eval_metrics``.
        """
        if not hooks_set:
            hooks_set = self.hooks_in_train

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        # Keep a reference to the training metrics before the evaluate()
        # calls below, which presumably overwrite ctx.eval_metrics.
        train_metrics = self.ctx.eval_metrics
        val_metrics = self.evaluate('val')
        test_metrics = self.evaluate('test')

        summary = {
            'train_loss': train_metrics['train_avg_loss'],
            'val_loss': val_metrics['val_avg_loss'],
            'test_loss': test_metrics['test_avg_loss'],
        }
        self.ctx.eval_metrics = summary

        return num_samples, self.get_model_para(), self.ctx.eval_metrics

    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.y_true``                      Move to `ctx.device`
            ``ctx.y_prob``                      Forward propagation, then
                                                ``scaler.inverse_transform``
            ``ctx.loss_batch``                  Calculate the loss
            ``ctx.batch_size``                  Get the batch_size
            ==================================  ===========================
        """
        # ctx.data_batch must unpack into exactly two items: inputs, target.
        inputs, target = [item.to(ctx.device) for item in ctx.data_batch]

        # Only the prediction goes through the scaler; the target is used
        # as-is (presumably already in the original scale -- confirm).
        prediction = self.scaler.inverse_transform(ctx.model(inputs))

        # Promote a 0-d target to shape (1,) so len() below is defined.
        if target.dim() == 0:
            target = target.unsqueeze(0)

        ctx.y_true = CtxVar(target, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(prediction, LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(prediction, target),
                                LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(target), LIFECYCLE.BATCH)