63 lines
2.6 KiB
Python
63 lines
2.6 KiB
Python
import numpy as np
|
|
import torch
|
|
|
|
from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer as Trainer
|
|
from federatedscope.trafficflow.dataset.normalization import StandardScaler
|
|
from federatedscope.core.trainers.enums import MODE, LIFECYCLE
|
|
from federatedscope.core.trainers.context import Context, CtxVar, lifecycle
|
|
|
|
def print_model_parameters(model):
    """Print a header followed by one ``name: shape`` line per parameter.

    Args:
        model: object whose ``named_parameters()`` yields ``(name, param)``
            pairs where ``param`` has a ``shape`` attribute (for example a
            ``torch.nn.Module``).
    """
    print("Model parameters and their shapes:")
    descriptions = (f"{pname}: {tensor.shape}"
                    for pname, tensor in model.named_parameters())
    for text in descriptions:
        print(text)
|
|
|
|
class TrafficflowTrainer(Trainer):
    """Federated trainer for traffic-flow forecasting models.

    Extends the general torch trainer in two ways:

    * every forward pass maps the model output through
      ``StandardScaler.inverse_transform`` before the loss is computed, and
    * ``train`` additionally evaluates on the ``'val'`` and ``'test'``
      splits and reports their average losses next to the training loss.
    """

    def __init__(self, model, scaler, args, data, device,
                 monitor):
        """Build the trainer.

        Args:
            model: model to train, forwarded to the base trainer.
            scaler: indexable pair; items ``[0]`` and ``[1]`` are passed to
                ``StandardScaler`` (presumably mean and std — confirm
                against the dataset loader).
            args: configuration, forwarded positionally to the base trainer
                after ``device`` (its config slot — confirm).
            data: client data, forwarded to the base trainer.
            device: device, forwarded to the base trainer.
            monitor: monitor, forwarded to the base trainer as ``monitor=``.
        """
        super().__init__(model, data, device, args, monitor=monitor)
        self.scaler = StandardScaler(scaler[0], scaler[1])

    def train(self, target_data_split_name="train", hooks_set=None):
        """Run one local training routine, then evaluate on val and test.

        Args:
            target_data_split_name: name of the split to train on.
            hooks_set: hooks to run; defaults to ``self.hooks_in_train``.

        Returns:
            Tuple ``(num_samples, model_parameters, eval_metrics)``.
            ``eval_metrics`` maps ``'train_loss'``, ``'val_loss'`` and
            ``'test_loss'`` to the matching ``<split>_avg_loss`` values.
            A val/test entry is omitted when ``evaluate`` reports no
            ``<split>_avg_loss`` for that split, instead of raising
            ``KeyError``.
        """
        hooks_set = hooks_set or self.hooks_in_train

        self.ctx.check_split(target_data_split_name)

        num_samples = self._run_routine(MODE.TRAIN, hooks_set,
                                        target_data_split_name)

        # Grab the training loss now, before evaluate() runs and may
        # overwrite ctx.eval_metrics with evaluation results.
        all_metrics = {'train_loss': self.ctx.eval_metrics['train_avg_loss']}

        # Evaluate in a fixed order (val, then test). NOTE(review): the base
        # Trainer.evaluate is expected to return an empty dict for a missing
        # split; such splits are skipped rather than crashing with KeyError.
        for split in ('val', 'test'):
            split_metrics = self.evaluate(split)
            avg_loss_key = f'{split}_avg_loss'
            if avg_loss_key in split_metrics:
                all_metrics[f'{split}_loss'] = split_metrics[avg_loss_key]

        self.ctx.eval_metrics = all_metrics

        return num_samples, self.get_model_para(), self.ctx.eval_metrics

    def _hook_on_batch_forward(self, ctx):
        """
        Note:
          The modified attributes and according operations are shown below:
            ==================================  ===========================
            Attribute                           Operation
            ==================================  ===========================
            ``ctx.y_true``                      Move to `ctx.device`
            ``ctx.y_prob``                      Forward propagation, then
                                                ``scaler.inverse_transform``
            ``ctx.loss_batch``                  Calculate the loss
            ``ctx.batch_size``                  Get the batch_size
            ==================================  ===========================
        """
        x, label = [tensor.to(ctx.device) for tensor in ctx.data_batch]
        # Map the model output back through the scaler before the loss.
        # NOTE(review): assumes the model predicts in normalized units while
        # labels stay unnormalized — confirm against the data loader.
        pred = self.scaler.inverse_transform(ctx.model(x))

        # Promote a 0-d (scalar) label to shape (1,) so len() below works.
        if label.dim() == 0:
            label = label.unsqueeze(0)

        ctx.y_true = CtxVar(label, LIFECYCLE.BATCH)
        ctx.y_prob = CtxVar(pred, LIFECYCLE.BATCH)
        ctx.loss_batch = CtxVar(ctx.criterion(pred, label), LIFECYCLE.BATCH)
        ctx.batch_size = CtxVar(len(label), LIFECYCLE.BATCH)
|