from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter


class TrafficSplitter(BaseSplitter):
    """Splitter for traffic-flow data, registered as ``'trafficflow'``.

    The actual partitioning logic is not implemented yet; see ``__call__``.
    """

    def __init__(self, client_num, **kwargs):
        """Forward ``client_num`` and any extra options to ``BaseSplitter``."""
        super().__init__(client_num, **kwargs)

    def __call__(self, dataset, *args, **kwargs):
        """Split ``dataset`` across clients (placeholder, not implemented).

        TODO: partitioning by subgraph labels is to be considered later.

        Args:
            dataset: ndarray(timestep, num_node, channel)
            *args: currently unused.
            **kwargs: currently unused.

        Returns:
            Intended result: [ndarray(timestep, per_node, channel) * client_nums].
            The current stub performs no work and returns ``None``.
        """
        # Deliberately left as a stub: the split strategy is still undecided.
        return None


def call_my_splitter(splitter_type, client_num, **kwargs):
    """Build a ``TrafficSplitter`` when ``splitter_type`` is ``'trafficflow'``.

    Args:
        splitter_type: name of the requested splitter.
        client_num: number of clients, passed through to the splitter.
        **kwargs: extra options passed through to the splitter.

    Returns:
        A ``TrafficSplitter`` instance for ``'trafficflow'``; ``None`` for any
        other ``splitter_type``.
    """
    # Return None explicitly for foreign types instead of touching an
    # unassigned local, so a non-matching request never raises here.
    if splitter_type != 'trafficflow':
        return None
    return TrafficSplitter(client_num, **kwargs)


register_splitter('trafficflow', call_my_splitter)