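"""A local-entropy trainer for FederatedScope.

Each local step minimizes the cross-entropy loss plus a proximal term that
pulls the model towards the current global weights, injects Gaussian noise
after every optimizer step (stochastic gradient Langevin dynamics), and
maintains an exponential moving average ``mu`` of the iterates, which is
returned as the local model. The proximal coefficient gamma is annealed
upwards by ``inc_factor`` after every batch.
"""
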
import math

import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


def copy_params(src):
    """Detach and clone the trainable parameters of `src` into a dict."""
    tgt = dict()
    for name, t in src.named_parameters():
        if t.requires_grad:
            tgt[name] = t.detach().clone()
    return tgt


def prox_term(cur, last):
    """Squared-L2 proximal term between the trainable parameters of `cur`
    and the snapshot `last` produced by `copy_params`."""
    loss = 0.0
    for name, w in cur.named_parameters():
        # Guard against parameters that were frozen when `last` was taken
        if name in last:
            loss += 0.5 * torch.sum((w - last[name])**2)
    return loss


def add_noise(model, sigma):
    """Perturb the trainable parameters with Gaussian noise of std `sigma`
    (the exploration step of Langevin dynamics)."""
    for p in model.parameters():
        if p.requires_grad:
            p.data += sigma * torch.randn_like(p)


def moving_avg(cur, new, alpha):
    """In-place exponential moving average:
    cur <- (1 - alpha) * cur + alpha * new."""
    for k, v in cur.items():
        v.data = (1 - alpha) * v + alpha * new[k]


class LocalEntropyTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # Configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.local_entropy_config = self.config.trainer.local_entropy
        # Proximal coefficient gamma, annealed upwards during training
        self._thermal = self.local_entropy_config.gamma

    def train(self):
        # Criterion & optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        current_global_model = copy_params(self.model)
        mu = copy_params(self.model)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion,
                                                 current_global_model, mu)

        # Adopt the averaged iterate `mu` as the local model
        for name, param in self.model.named_parameters():
            if name in mu:
                param.data = mu[name]

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def run_epoch(self, optimizer, criterion, current_global_model, mu):
        running_loss = 0.0
        num_samples = 0
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Descent step: cross-entropy plus the proximal pull
            # towards the current global model
            optimizer.zero_grad()
            outputs = self.model(inputs)
            ce_loss = criterion(outputs, targets)
            loss = ce_loss + self._thermal * prox_term(self.model,
                                                       current_global_model)
            loss.backward()
            optimizer.step()

            # Add noise for Langevin dynamics
            add_noise(
                self.model,
                math.sqrt(self.optim_config.lr) *
                self.local_entropy_config.eps)

            # Accumulate local updates into the moving average `mu`
            moving_avg(mu, self.model.state_dict(),
                       self.local_entropy_config.alpha)

            with torch.no_grad():
                running_loss += targets.shape[0] * ce_loss.item()

            num_samples += targets.shape[0]
            # Anneal the proximal coefficient upwards
            self._thermal *= self.local_entropy_config.inc_factor

        return num_samples, running_loss

    def evaluate(self, target_data_split_name='test'):
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)

            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0
            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc': float(num_corrects) /
                float(num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss': total_loss /
                float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict=strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
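

# A minimal registration sketch, assuming this file is used as a
# FederatedScope contrib trainer discovered via `federatedscope.register`.
# The `call_local_entropy_trainer` helper and the trainer-type string
# 'local_entropy_trainer' are illustrative assumptions, not part of the
# code above; with them in place, setting `trainer.type` to
# 'local_entropy_trainer' in the config would select this trainer.
from federatedscope.register import register_trainer


def call_local_entropy_trainer(trainer_type):
    if trainer_type == 'local_entropy_trainer':
        return LocalEntropyTrainer


register_trainer('local_entropy_trainer', call_local_entropy_trainer)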