'''The implementations of ASAM and SAM are borrowed from
    https://github.com/debcaldarola/fedsam
   Caldarola, D., Caputo, B., & Ciccone, M.
   Improving Generalization in Federated Learning by Seeking Flat Minima,
   European Conference on Computer Vision (ECCV) 2022.
'''
from collections import defaultdict

import torch

from federatedscope.core.trainers import BaseTrainer
from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer


class ASAM(object):
    """Adaptive Sharpness-Aware Minimization (ASAM).

    Wraps a base optimizer: ``ascent_step`` perturbs the parameters along
    a parameter-scaled worst-case direction, ``descent_step`` undoes the
    perturbation and applies the actual optimizer update.
    """
    def __init__(self, optimizer, model, rho=0.5, eta=0.01):
        self.optimizer = optimizer
        self.model = model
        self.rho = rho
        self.eta = eta
        self.state = defaultdict(dict)

    @torch.no_grad()
    def ascent_step(self):
        wgrads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if t_w is None:
                t_w = torch.clone(p).detach()
                self.state[p]["eps"] = t_w
            if 'weight' in n:
                # Adaptive scaling: weight the gradient by |w| + eta
                t_w[...] = p[...]
                t_w.abs_().add_(self.eta)
                p.grad.mul_(t_w)
            wgrads.append(torch.norm(p.grad, p=2))
        wgrad_norm = torch.norm(torch.stack(wgrads), p=2) + 1.e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            t_w = self.state[p].get("eps")
            if 'weight' in n:
                p.grad.mul_(t_w)
            eps = t_w
            eps[...] = p.grad[...]
            eps.mul_(self.rho / wgrad_norm)
            p.add_(eps)
        self.optimizer.zero_grad()

    @torch.no_grad()
    def descent_step(self):
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            # Undo the ascent perturbation before the real update
            p.sub_(self.state[p]["eps"])
        self.optimizer.step()
        self.optimizer.zero_grad()


class SAM(ASAM):
    """Sharpness-Aware Minimization (SAM): non-adaptive variant of ASAM."""
    @torch.no_grad()
    def ascent_step(self):
        grads = []
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            grads.append(torch.norm(p.grad, p=2))
        grad_norm = torch.norm(torch.stack(grads), p=2) + 1.e-16
        for n, p in self.model.named_parameters():
            if p.grad is None:
                continue
            eps = self.state[p].get("eps")
            if eps is None:
                eps = torch.clone(p).detach()
                self.state[p]["eps"] = eps
            eps[...] = p.grad[...]
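            # Scale the raw gradient onto the rho-ball:
            # eps = rho * grad / ||grad||_2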
            eps.mul_(self.rho / grad_norm)
            p.add_(eps)
        self.optimizer.zero_grad()


class SAMTrainer(BaseTrainer):
    def __init__(self, model, data, device, **kwargs):
        # NN modules
        self.model = model
        # FS `ClientData` or your own data
        self.data = data
        # Device name
        self.device = device
        # configs
        self.kwargs = kwargs
        self.config = kwargs['config']
        self.optim_config = self.config.train.optimizer
        self.sam_config = self.config.trainer.sam

    def train(self):
        # Criterion & Optimizer
        criterion = torch.nn.CrossEntropyLoss().to(self.device)
        optimizer = get_optimizer(self.model, **self.optim_config)

        # _hook_on_fit_start_init
        self.model.to(self.device)
        self.model.train()

        num_samples, total_loss = self.run_epoch(optimizer, criterion)

        # _hook_on_fit_end
        return num_samples, self.model.cpu().state_dict(), \
            {'loss_total': total_loss,
             'avg_loss': total_loss / float(num_samples)}

    def run_epoch(self, optimizer, criterion):
        if self.sam_config.adaptive:
            minimizer = ASAM(optimizer,
                             self.model,
                             rho=self.sam_config.rho,
                             eta=self.sam_config.eta)
        else:
            minimizer = SAM(optimizer,
                            self.model,
                            rho=self.sam_config.rho,
                            eta=self.sam_config.eta)
        running_loss = 0.0
        num_samples = 0
        # for inputs, targets in self.trainloader:
        for inputs, targets in self.data['train']:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device)

            # Ascent Step: gradients at the current weights, then perturb
            outputs = self.model(inputs)
            loss = criterion(outputs, targets)
            loss.backward()
            minimizer.ascent_step()

            # Descent Step: gradients at the perturbed weights, then update
            criterion(self.model(inputs), targets).backward()
            minimizer.descent_step()

            with torch.no_grad():
                running_loss += targets.shape[0] * loss.item()
                num_samples += targets.shape[0]
        return num_samples, running_loss

    def evaluate(self, target_data_split_name='test'):
        if target_data_split_name != 'test':
            return {}

        with torch.no_grad():
            criterion = torch.nn.CrossEntropyLoss().to(self.device)
            self.model.to(self.device)
            self.model.eval()
            total_loss = num_samples = num_corrects = 0

            # _hook_on_batch_start_init
            for x, y in self.data[target_data_split_name]:
                # _hook_on_batch_forward
                x, y = x.to(self.device), y.to(self.device)
                pred = self.model(x)
                loss = criterion(pred, y)
                cor = torch.sum(torch.argmax(pred, dim=-1).eq(y))

                # _hook_on_batch_end
                total_loss += loss.item() * y.shape[0]
                num_samples += y.shape[0]
                num_corrects += cor.item()

            # _hook_on_fit_end
            return {
                f'{target_data_split_name}_acc': float(num_corrects) /
                float(num_samples),
                f'{target_data_split_name}_loss': total_loss,
                f'{target_data_split_name}_total': num_samples,
                f'{target_data_split_name}_avg_loss': total_loss /
                float(num_samples)
            }

    def update(self, model_parameters, strict=False):
        self.model.load_state_dict(model_parameters, strict)

    def get_model_para(self):
        return self.model.cpu().state_dict()
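

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the trainer above).
# FederatedScope custom trainers are typically wired in via its
# `register_trainer` hook; the trainer name 'sam_trainer' below is an
# assumption, not something defined in this file:
#
#     from federatedscope.register import register_trainer
#
#     def call_sam_trainer(trainer_type):
#         if trainer_type == 'sam_trainer':
#             return SAMTrainer
#
#     register_trainer('sam_trainer', call_sam_trainer)
#
# The guarded demo below exercises the SAM minimizer standalone to show the
# two-pass update; the toy model, random data, and hyper-parameters are
# hypothetical placeholders.
if __name__ == '__main__':
    torch.manual_seed(0)
    demo_model = torch.nn.Linear(8, 2)
    demo_optimizer = torch.optim.SGD(demo_model.parameters(), lr=0.1)
    demo_minimizer = SAM(demo_optimizer, demo_model, rho=0.05)
    demo_criterion = torch.nn.CrossEntropyLoss()
    demo_x = torch.randn(16, 8)
    demo_y = torch.randint(0, 2, (16, ))
    for step in range(3):
        # First pass: gradients at the current weights
        demo_loss = demo_criterion(demo_model(demo_x), demo_y)
        demo_loss.backward()
        demo_minimizer.ascent_step()
        # Second pass: gradients at the perturbed weights
        demo_criterion(demo_model(demo_x), demo_y).backward()
        demo_minimizer.descent_step()
        print(f'step {step}: loss={demo_loss.item():.4f}')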