from federatedscope.register import register_trainer
from federatedscope.core.trainers import BaseTrainer

# An example of converting a plain PyTorch training loop into an FS trainer.
# Refer to `federatedscope.core.trainers.BaseTrainer` for the interface.

# Try with FEMNIST:
#  python federatedscope/main.py --cfg scripts/example_configs/femnist.yaml \
#  trainer.type mytorchtrainer federate.sample_client_rate 0.01 \
#  federate.total_round_num 5 eval.best_res_update_round_wise_key test_loss

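# Note: the `# _hook_on_*` comments in the methods below are plain comments
# with no runtime effect; they indicate which stage of FS's built-in
# hook-based trainers each plain-PyTorch step corresponds to.
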
class MyTorchTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        import torch
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # kwargs
        self.kwargs = kwargs
        # Criterion; the optimizer is rebuilt in `train()` so that each
        # round of local training starts from fresh optimizer state
        self.criterion = torch.nn.CrossEntropyLoss()

    def train(self):
        """Run one round of local training and return the triple
        ``(num_samples, model_parameters, results)`` expected by FS."""
        import torch
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         lr=0.001,
                                         momentum=0.9,
                                         weight_decay=1e-4)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        self.model.train()

        total_loss = num_samples = 0
        # _hook_on_batch_start_init
        for x, y in self.data['train']:
            # _hook_on_batch_forward
            x, y = x.to(self.device), y.to(self.device)
            outputs = self.model(x)
            loss = self.criterion(outputs, y)

            # _hook_on_batch_backward
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            # _hook_on_batch_end
            total_loss += loss.item() * y.shape[0]
            num_samples += y.shape[0]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def evaluate(self, target_data_split_name='test'):
        """Evaluate on the given data split and return a dict of metrics
        prefixed by the split name (e.g. `test_loss`, which the example
        command above selects via `eval.best_res_update_round_wise_key`)."""
        import torch
        with torch.no_grad():
            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = self.criterion(pred, y)

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss':
                total_loss / float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        """Load the global model parameters distributed by FS."""
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        """Return the current (CPU) model parameters for uploading."""
        return self.model.cpu().state_dict()


def call_my_torch_trainer(trainer_type):
    # Builder that FS looks up against `trainer.type`; return the class
    # only when the requested type matches
    if trainer_type == 'mytorchtrainer':
        return MyTorchTrainer


register_trainer('mytorchtrainer', call_my_torch_trainer)
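
# A minimal local smoke test (not part of the FS example itself): it runs
# one round of training and an evaluation on a toy model and fabricated
# dataset. The shapes, sizes, and reuse of one loader for both splits are
# illustrative assumptions only, not FS requirements.
if __name__ == '__main__':
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    toy_model = torch.nn.Linear(10, 3)
    toy_loader = DataLoader(TensorDataset(torch.randn(32, 10),
                                          torch.randint(0, 3, (32, ))),
                            batch_size=8)
    trainer = MyTorchTrainer(model=toy_model,
                             data={'train': toy_loader, 'test': toy_loader},
                             device='cpu')
    num_samples, _, results = trainer.train()
    print(f'trained on {num_samples} samples: {results}')
    print(trainer.evaluate('test'))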