FS-TFP/federatedscope/contrib/model/fedsam_convnet.py

59 lines
1.7 KiB
Python

'''The implementations of ASAM and SAM are borrowed from
https://github.com/debcaldarola/fedsam
Caldarola, D., Caputo, B., & Ciccone, M.
Improving Generalization in Federated Learning by Seeking Flat Minima,
European Conference on Computer Vision (ECCV) 2022.
'''
import os
import re
from typing import Callable
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from federatedscope.register import register_model
class Conv2Model(nn.Module):
    """Small two-stage convolutional classifier.

    The network applies two ``Conv2d(kernel_size=5) -> ReLU -> MaxPool2d(2)``
    stages with 64 output channels each, flattens the result, and feeds it to
    a fully connected head ``1600 -> 384 -> 192 -> num_classes``.

    The first linear layer is fixed at ``64 * 5 * 5`` input features, which is
    what the two stages produce for 3-channel 32x32 inputs
    (32 -> 28 -> 14 -> 10 -> 5 per spatial side).

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        # Kept as a public attribute; the last linear layer is sized from it.
        self.num_classes = num_classes

        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=5),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        self.classifier = nn.Sequential(
            nn.Linear(64 * 5 * 5, 384),
            nn.ReLU(),
            nn.Linear(384, 192),
            nn.ReLU(),
            nn.Linear(192, self.num_classes),
        )

        # Cached parameter count; must be computed after all layers exist.
        self.size = self.model_size()

    def forward(self, x):
        """Compute class logits for a batch of images.

        Args:
            x: Input tensor laid out as ``(batch, 3, height, width)``; the
                spatial size must leave 5x5 feature maps after both stages
                (e.g. 32x32) so that the flattened width is 1600.

        Returns:
            Tensor of shape ``(batch, num_classes)`` with unnormalised scores.
        """
        x = self.layer1(x)
        x = self.layer2(x)
        # Collapse (channels, h, w) into one feature axis, keeping the batch
        # axis. flatten() avoids the ambiguous ``-1`` of reshape on an empty
        # batch.
        x = torch.flatten(x, start_dim=1)
        return self.classifier(x)

    def model_size(self):
        """Return the total number of scalar parameters in the model.

        Every element of every parameter tensor is counted. The previous
        implementation added only the leading dimension of each tensor,
        which under-reported the size (1428 instead of 797962 for 10 classes)
        and would fail on a 0-dim parameter.
        """
        return sum(param.numel() for param in self.parameters())
def call_fedsam_conv2(model_config, local_data):
    """Build a ``Conv2Model`` when the configuration selects it.

    Args:
        model_config: Configuration object; only its ``type`` attribute is
            read here.
        local_data: Accepted to match the builder signature; not used.

    Returns:
        A ``Conv2Model`` with 10 output classes when ``model_config.type`` is
        ``'fedsam_conv2'``; otherwise ``None``.
    """
    # Any other model type is not ours to build.
    if model_config.type != 'fedsam_conv2':
        return None
    # The class count is fixed at 10 regardless of the configuration.
    return Conv2Model(10)
# Pass the builder to federatedscope.register under the name 'fedsam_conv2',
# the same string call_fedsam_conv2 compares model_config.type against.
# NOTE(review): presumably this lets the framework look the builder up by
# model type at construction time -- confirm against register_model.
register_model('fedsam_conv2', call_fedsam_conv2)