from federatedscope.register import register_splitter from federatedscope.core.splitters import BaseSplitter class MySplitter(BaseSplitter): def __init__(self, client_num, **kwargs): super(MySplitter, self).__init__(client_num, **kwargs) def __call__(self, dataset, *args, **kwargs): # Dummy splitter, only for demonstration per_samples = len(dataset) // self.client_num data_list, cur_index = [], 0 for i in range(self.client_num): data_list.append( [x for x in range(cur_index, cur_index + per_samples)]) cur_index += per_samples return data_list def call_my_splitter(splitter_type, client_num, **kwargs): if splitter_type == 'mysplitter': splitter = MySplitter(client_num, **kwargs) return splitter register_splitter('mysplitter', call_my_splitter)