FS-TFP/federatedscope/contrib/splitter/example.py

27 lines
871 B
Python

from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter
class MySplitter(BaseSplitter):
    """Dummy splitter that partitions sample indices into contiguous chunks.

    Only for demonstration of how a custom splitter is written and
    registered; it does not look at the content of the samples.
    """
    def __init__(self, client_num, **kwargs):
        super().__init__(client_num, **kwargs)

    def __call__(self, dataset, *args, **kwargs):
        """Split ``range(len(dataset))`` into ``self.client_num`` chunks.

        Args:
            dataset: any sized object; only ``len(dataset)`` is used.
            *args, **kwargs: accepted for interface compatibility, unused.

        Returns:
            A list with ``self.client_num`` lists of consecutive integer
            indices. Chunk sizes differ by at most one: the first
            ``len(dataset) % self.client_num`` chunks receive one extra
            index, so every index in ``range(len(dataset))`` is assigned
            to exactly one chunk (none are dropped).

        NOTE(review): this returns index lists, not subsets of ``dataset``;
        confirm that is what the consuming data pipeline expects.
        """
        # divmod keeps the remainder instead of discarding it, which the
        # plain floor division used to do.
        base, extra = divmod(len(dataset), self.client_num)
        data_list, cur_index = [], 0
        for client_idx in range(self.client_num):
            size = base + (1 if client_idx < extra else 0)
            data_list.append(list(range(cur_index, cur_index + size)))
            cur_index += size
        return data_list
def call_my_splitter(splitter_type, client_num, **kwargs):
    """Build a ``MySplitter`` when ``splitter_type`` is ``'mysplitter'``.

    Args:
        splitter_type: name of the requested splitter.
        client_num: number of clients, forwarded to ``MySplitter``.
        **kwargs: extra options, forwarded to ``MySplitter``.

    Returns:
        A new ``MySplitter`` instance for the ``'mysplitter'`` type,
        otherwise ``None``.
    """
    if splitter_type != 'mysplitter':
        return None
    return MySplitter(client_num, **kwargs)
# Executed at import time: registers ``call_my_splitter`` under the key
# 'mysplitter'. Presumably the framework looks splitters up by this key
# (e.g. from a config's splitter type) — confirm against register_splitter.
register_splitter('mysplitter', call_my_splitter)