FS-TFP/federatedscope/contrib/trainer/sam.py

'''The implementations of ASAM and SAM are borrowed from
https://github.com/debcaldarola/fedsam
Caldarola, D., Caputo, B., & Ciccone, M.
Improving Generalization in Federated Learning by Seeking Flat Minima,
European Conference on Computer Vision (ECCV) 2022.
'''
from collections import defaultdict
import torch
from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer
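

# ASAM (Adaptive Sharpness-Aware Minimization) wraps a base optimizer and a
# model and exposes a two-step update:
#   1. ascent_step():  after a first backward pass, perturb the parameters
#      towards higher loss within an L2 ball of radius `rho`; weight
#      gradients are rescaled element-wise by |w| + eta, which makes the
#      perturbation adaptive to the parameter scale.
#   2. descent_step(): after a second backward pass at the perturbed point,
#      restore the original parameters and apply the base optimizer step.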
class ASAM(object):
    def __init__(self, optimizer, model, rho=0.5, eta=0.01):
        self.optimizer = optimizer
        self.model = model
        self.rho = rho
        self.eta = eta
        self.state = defaultdict(dict)

    @torch.no_grad()
    def ascent_step(self):
        wgrads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if t_w is None:
                t_w = torch.clone(p).detach()
                self.state[p]["eps"] = t_w
            if 'weight' in n:
                # Adaptive scaling: rescale the gradient by |w| + eta
                t_w[...] = p[...]
                t_w.abs_().add_(self.eta)
                p.grad.mul_(t_w)
            wgrads.append(torch.norm(p.grad, p=2))
        wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if 'weight' in n:
                p.grad.mul_(t_w)
            eps = t_w
            eps[...] = p.grad[...]
            eps.mul_(self.rho / wgrad_norm)
            # Perturb the parameters in place; descent_step() undoes this
            p.add_(eps)
        self.optimizer.zero_grad()

    @torch.no_grad()
    def descent_step(self):
        # Restore the original (unperturbed) parameters, then apply the base
        # optimizer step with the gradients computed at the perturbed point
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            p.sub_(self.state[p]["eps"])
        self.optimizer.step()
        self.optimizer.zero_grad()
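

# SAM (Sharpness-Aware Minimization) inherits descent_step() from ASAM and
# only overrides ascent_step(): the perturbation is the plain normalized
# gradient, eps = rho * grad / ||grad||_2, without the |w| + eta rescaling.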
class SAM(ASAM):
    @torch.no_grad()
    def ascent_step(self):
        grads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            grads.append(torch.norm(p.grad, p=2))
        grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            eps = self.state[p].get("eps")
            if eps is None:
                eps = torch.clone(p).detach()
                self.state[p]["eps"] = eps
            eps[...] = p.grad[...]
            eps.mul_(self.rho / grad_norm)
            p.add_(eps)
        self.optimizer.zero_grad()
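

# SAMTrainer plugs SAM/ASAM local training into FederatedScope through the
# BaseTrainer interface (train / evaluate / update / get_model_para).  The
# base optimizer settings come from `cfg.train.optimizer` and the SAM
# hyperparameters from `cfg.trainer.sam` (fields: adaptive, rho, eta).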
class SAMTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # Configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.sam_config = self.config.trainer.sam

    def train(self):
        # Criterion & Optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion)

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}
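
    # One local epoch of SAM/ASAM training.  Each batch costs two
    # forward/backward passes: one at the current weights to build the
    # perturbation, and one at the perturbed weights to compute the update.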
    def run_epoch(self, optimizer, criterion):
        if self.sam_config.adaptive:
            minimizer = ASAM(optimizer,
                             self.model,
                             rho=self.sam_config.rho,
                             eta=self.sam_config.eta)
        else:
            minimizer = SAM(optimizer,
                            self.model,
                            rho=self.sam_config.rho,
                            eta=self.sam_config.eta)
        running_loss = 0.0
        num_samples = 0
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Ascent step: gradient at the current weights, then perturb them
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            minimizer.ascent_step()

            # Descent step: gradient at the perturbed weights, then restore
            # the weights and apply the base optimizer update
            criterion(self.model(inputs), targets).backward()
            minimizer.descent_step()

            with torch.no_grad():
                running_loss += targets.shape[0] * loss.item()
                num_samples += targets.shape[0]
        return num_samples, running_loss
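
    # Evaluation is plain inference: no SAM perturbation is applied, only a
    # cross-entropy forward pass over the requested split.  Splits other
    # than 'test' return an empty dict.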
    def evaluate(self, target_data_split_name='test'):
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)

            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0

            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc': float(num_corrects) /
                float(num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss': total_loss /
                float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
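

# Minimal registration sketch so that the trainer can be selected from a
# FederatedScope config.  This block is illustrative: it assumes the standard
# `federatedscope.register.register_trainer` hook, and the name 'sam_trainer'
# is only an example that would have to match the configured trainer type.
from federatedscope.register import register_trainer


def call_sam_trainer(trainer_type):
    # Hypothetical dispatch helper: return the trainer class when the
    # configured trainer type matches the registered name.
    if trainer_type == 'sam_trainer':
        return SAMTrainer


register_trainer('sam_trainer', call_sam_trainer)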