FS-TFP/federatedscope/contrib/trainer/local_entropy.py

import math
from collections import defaultdict

import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


def copy_params(src):
    """Return detached copies of the trainable parameters of ``src``."""
    tgt = dict()
    for name, t in src.named_parameters():
        if t.requires_grad:
            tgt[name] = t.detach().clone()
    return tgt


def prox_term(cur, last):
    """0.5 * squared L2 distance between ``cur`` and the snapshot ``last``."""
    loss = 0.0
    for name, w in cur.named_parameters():
        # skip frozen parameters that were not snapshotted by copy_params
        if name in last:
            loss += 0.5 * torch.sum((w - last[name])**2)
    return loss


def add_noise(model, sigma):
    """Inject Gaussian noise of std ``sigma`` (Langevin dynamics) in place."""
    for p in model.parameters():
        if p.requires_grad:
            p.data += sigma * torch.randn(size=p.shape, device=p.device)


def moving_avg(cur, new, alpha):
    """In-place EMA: cur <- (1 - alpha) * cur + alpha * new."""
    for k, v in cur.items():
        v.data = (1 - alpha) * v + alpha * new[k]
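
# Taken together, the helpers above implement one "local entropy" update of
# the kind used by LocalEntropyTrainer below: a proximal SGD step toward a
# snapshot of the global model, Gaussian (Langevin) noise injection, and an
# exponential moving average of the local iterates. A minimal sketch of a
# single such step on a toy model (`toy_model`, `toy_opt` and the constants
# are hypothetical, for illustration only):
#
#     toy_model = torch.nn.Linear(4, 2)
#     anchor = copy_params(toy_model)       # snapshot of the "global" model
#     mu = copy_params(toy_model)           # running average of local iterates
#     toy_opt = torch.optim.SGD(toy_model.parameters(), lr=0.1)
#
#     toy_opt.zero_grad()
#     loss = toy_model(torch.randn(8, 4)).pow(2).mean() \
#         + 0.1 * prox_term(toy_model, anchor)
#     loss.backward()
#     toy_opt.step()
#     add_noise(toy_model, math.sqrt(0.1) * 1e-4)        # Langevin noise
#     moving_avg(mu, toy_model.state_dict(), alpha=0.1)  # accumulate iterate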


class LocalEntropyTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.local_entropy_config = self.config.trainer.local_entropy
        self._thermal = self.local_entropy_config.gamma
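
    # The config accesses in __init__ assume entries of roughly the following
    # shape. The key names mirror the attribute lookups above; the values and
    # the `trainer.type` string are illustrative assumptions, not defaults
    # shipped with FederatedScope:
    #
    #   train:
    #     optimizer:
    #       lr: 0.01
    #   trainer:
    #     type: local_entropy_trainer
    #     local_entropy:
    #       gamma: 0.03        # weight of the proximal term (self._thermal)
    #       eps: 1.e-4         # scale of the Langevin noise
    #       alpha: 0.75        # moving-average coefficient for mu
    #       inc_factor: 1.0005 # per-step multiplicative increase of gamma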

    def train(self):
        # Criterion & Optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        current_global_model = copy_params(self.model)
        mu = copy_params(self.model)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion,
                                                 current_global_model, mu)

        # adopt the moving average `mu` of the local iterates as the model
        # returned to the server
        for name, param in self.model.named_parameters():
            if name in mu:
                param.data = mu[name]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def run_epoch(self, optimizer, criterion, current_global_model, mu):
        running_loss = 0.0
        num_samples = 0
        # for inputs, targets in self.trainloader:
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Descent Step
            optimizer.zero_grad()
            outputs = self.model(inputs)
            ce_loss = criterion(outputs, targets)
            loss = ce_loss + self._thermal * prox_term(self.model,
                                                       current_global_model)
            loss.backward()
            optimizer.step()

            # add noise for langevin dynamics; the noise std scales with
            # sqrt(lr), as in SGLD, times the configured eps
            add_noise(
                self.model,
                math.sqrt(self.optim_config.lr) *
                self.local_entropy_config.eps)

            # accumulate local updates into the moving average mu
            moving_avg(mu, self.model.state_dict(),
                       self.local_entropy_config.alpha)

            with torch.no_grad():
                running_loss += targets.shape[0] * ce_loss.item()
                num_samples += targets.shape[0]

            # gradually increase the proximal weight (scoping of gamma)
            self._thermal *= self.local_entropy_config.inc_factor
        return num_samples, running_loss

    def evaluate(self, target_data_split_name='test'):
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)
            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0

            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc': float(num_corrects) /
                float(num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss': total_loss /
                float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
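

# How this trainer would typically be exposed to FederatedScope: contrib
# trainers are registered so that `cfg.trainer.type` can select them by name.
# A minimal sketch, assuming the type string 'local_entropy_trainer' (the
# exact string used by this project is an assumption):
#
#     from federatedscope.register import register_trainer
#
#     def call_local_entropy_trainer(trainer_type):
#         if trainer_type == 'local_entropy_trainer':
#             return LocalEntropyTrainer
#
#     register_trainer('local_entropy_trainer', call_local_entropy_trainer)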