59 lines
1.7 KiB
Python
59 lines
1.7 KiB
Python
'''The implementations of ASAM and SAM are borrowed from
https://github.com/debcaldarola/fedsam

Caldarola, D., Caputo, B., & Ciccone, M.
Improving Generalization in Federated Learning by Seeking Flat Minima,
European Conference on Computer Vision (ECCV) 2022.
'''
|
|
import os
|
|
import re
|
|
from typing import Callable
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
from PIL import Image
|
|
|
|
from federatedscope.register import register_model
|
|
|
|
|
|
class Conv2Model(nn.Module):
    """Small CNN: two conv/ReLU/max-pool stages followed by an MLP head.

    Each convolution uses a 5x5 kernel without padding and is followed by
    a 2x2 max-pool. The head expects ``64 * 5 * 5`` flattened features,
    which is what a 3-channel 32x32 input produces after both stages.

    Args:
        num_classes: Width of the final linear layer (number of logits).

    Attributes set in ``__init__``:
        num_classes: The value passed to the constructor.
        layer1, layer2: The two convolutional stages.
        classifier: The fully connected head (1600 -> 384 -> 192 -> logits).
        size: Result of :meth:`model_size`, computed once at construction.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.num_classes = num_classes

        # Stage 1: 3 -> 64 channels.
        self.layer1 = nn.Sequential(
            nn.Conv2d(3, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Stage 2: 64 -> 64 channels.
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 64, 5),
            nn.ReLU(),
            nn.MaxPool2d(2),
        )

        # Fully connected head operating on the flattened feature map.
        flat_features = 64 * 5 * 5
        self.classifier = nn.Sequential(
            nn.Linear(flat_features, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, self.num_classes),
        )

        self.size = self.model_size()

    def forward(self, x):
        """Run the network on a batch and return the raw logits.

        The batch dimension (dim 0) is kept; every other dimension of the
        convolutional output is flattened before the classifier.
        """
        features = self.layer2(self.layer1(x))
        flat = features.reshape(features.shape[0], -1)
        return self.classifier(flat)

    def model_size(self):
        """Return the sum of the leading dimension of every parameter.

        NOTE: only ``shape[0]`` of each parameter tensor is accumulated, so
        this is not the total number of scalar parameters.
        """
        return sum(weight.shape[0] for weight in self.parameters())
|
|
|
|
|
|
def call_fedsam_conv2(model_config, local_data):
    """Build a ``Conv2Model`` when the config asks for ``'fedsam_conv2'``.

    Args:
        model_config: Object whose ``type`` attribute names the requested
            model.
        local_data: Accepted for signature compatibility; not used here.

    Returns:
        A ``Conv2Model`` with 10 output classes if ``model_config.type`` is
        ``'fedsam_conv2'``; otherwise ``None``.
    """
    requested = model_config.type
    if requested == 'fedsam_conv2':
        # The number of classes is fixed at 10 for this builder.
        return Conv2Model(10)
    return None
|
|
|
|
|
|
# Register the builder under the key 'fedsam_conv2' at import time.
# NOTE(review): presumably FederatedScope consults registered builders when
# constructing a model from the config -- confirm against
# federatedscope.register.
register_model('fedsam_conv2', call_fedsam_conv2)
|