'''The implementation of ASAM and SAM are borrowed from
    https://github.com/debcaldarola/fedsam
   Caldarola, D., Caputo, B., & Ciccone, M.
   Improving Generalization in Federated Learning by Seeking Flat Minima,
   European Conference on Computer Vision (ECCV) 2022.
'''
import re
import json

import numpy as np

from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter

# Dataset sizes used to tell which partition file has to be loaded.
_TRAIN_SIZE = 50000
_TEST_SIZE = 10000
# Key of the single ``user_data`` entry that is read from ``test.json``.
_TEST_USER_ID = '100'
# Extracts the numeric sample index out of names such as ``img_123_label``.
_IMG_NAME_PATTERN = re.compile(r'img_(\d+)_label')


class FedSAM_CIFAR10_Splitter(BaseSplitter):
    """
    This splitter splits according to the partition files FedSAM provides.

    The partition is not computed here: it is read from JSON files under
    ``data/fedsam_cifar10/data/``, and ``alpha`` only selects which train
    file is loaded.

    Args:
        client_num: the dataset will be split into ``client_num`` pieces
        alpha (float): Partition hyperparameter in LDA, smaller alpha
            generates more extreme heterogeneous scenario, see
            ``np.random.dirichlet``. It is formatted with two decimals to
            build the train partition file name.
    """
    def __init__(self, client_num, alpha=0.5):
        self.alpha = alpha
        super(FedSAM_CIFAR10_Splitter, self).__init__(client_num)

    @staticmethod
    def _get_idx(name_list):
        """Return the integer sample index embedded in each file name.

        Raises:
            IndexError: if a name contains no ``img_<digits>_label`` part.
        """
        return [int(_IMG_NAME_PATTERN.findall(fn)[0]) for fn in name_list]

    def __call__(self, dataset, prior=None, **kwargs):
        """Split ``dataset`` into ``client_num`` lists of samples.

        Args:
            dataset: iterable of samples; it must contain exactly 50000
                (train split) or 10000 (test split) items.
            prior: accepted for interface compatibility, not used.
            **kwargs: accepted for interface compatibility, not used.

        Returns:
            list: one list of samples per client.

        Raises:
            ValueError: if the dataset size is neither 50000 nor 10000.
        """
        dataset = list(dataset)
        num_samples = len(dataset)

        if num_samples == _TRAIN_SIZE:
            filename = ('data/fedsam_cifar10/data/train/'
                        f'federated_train_alpha_{self.alpha:.2f}.json')
        elif num_samples == _TEST_SIZE:
            filename = 'data/fedsam_cifar10/data/test/test.json'
        else:
            # Previously this fell through and failed with an unbound
            # ``filename``; fail early with an actionable message instead.
            raise ValueError(
                f'FedSAM_CIFAR10_Splitter expects a dataset with '
                f'{_TRAIN_SIZE} (train) or {_TEST_SIZE} (test) samples, '
                f'but got {num_samples}.')

        with open(filename, 'r', encoding='utf-8') as ips:
            user_data = json.load(ips)['user_data']

        if num_samples == _TRAIN_SIZE:
            # The train file holds one entry per client, keyed by its id.
            idx_slice = [
                self._get_idx(user_data[str(uid)]['x'])
                for uid in range(self.client_num)
            ]
        else:
            # The test indices are read from a single entry and are cut
            # into ``client_num`` nearly equal chunks.
            test_idx = self._get_idx(user_data[_TEST_USER_ID]['x'])
            idx_slice = np.array_split(np.array(test_idx), self.client_num)

        data_list = [[dataset[idx] for idx in idxs] for idxs in idx_slice]
        return data_list


def call_fedsam_cifar10_splitter(splitter_type, client_num, **kwargs):
    """Build the splitter when ``splitter_type`` names it.

    Args:
        splitter_type (str): name of the requested splitter.
        client_num: number of pieces the dataset will be split into.
        **kwargs: forwarded to ``FedSAM_CIFAR10_Splitter`` (e.g. ``alpha``).

    Returns:
        A ``FedSAM_CIFAR10_Splitter`` if ``splitter_type`` is
        ``'fedsam_cifar10_splitter'``, otherwise ``None``.
    """
    if splitter_type == 'fedsam_cifar10_splitter':
        return FedSAM_CIFAR10_Splitter(client_num, **kwargs)
    return None


register_splitter('fedsam_cifar10_splitter', call_fedsam_cifar10_splitter)