33 lines
815 B
Python
33 lines
815 B
Python
|
|
from federatedscope.register import register_splitter
|
|
from federatedscope.core.splitters import BaseSplitter
|
|
|
|
|
|
class TrafficSplitter(BaseSplitter):
    """Split a traffic-flow tensor across clients along the node axis.

    Instances are built by ``call_my_splitter`` for the ``'trafficflow'``
    splitter type.
    """

    def __init__(self, client_num, **kwargs):
        """Initialise the splitter through ``BaseSplitter``.

        Args:
            client_num: number of clients the data is divided among.
            **kwargs: forwarded unchanged to ``BaseSplitter.__init__``.
        """
        super(TrafficSplitter, self).__init__(client_num, **kwargs)

    def __call__(self, dataset, *args, **kwargs):
        """Partition ``dataset`` into ``self.client_num`` groups of nodes.

        Each client receives a contiguous range of node indices. Range sizes
        differ by at most one when ``num_node`` is not divisible by the
        number of clients.

        TODO: replace with a graph-aware subgraph partition. This contiguous
        split uses node index order only and no connectivity information.

        Args:
            dataset: array-like of shape ``(timestep, num_node, channel)``.
            *args: unused; accepted for ``BaseSplitter`` call compatibility.
            **kwargs: unused; accepted for ``BaseSplitter`` call
                compatibility.

        Returns:
            list of ``self.client_num`` numpy arrays. The i-th array has
            shape ``(timestep, n_i, channel)``, and ``sum(n_i) == num_node``.

        Raises:
            ValueError: if ``dataset`` is not 3-D, or if ``self.client_num``
                is not in ``[1, num_node]`` (a client would otherwise get
                zero nodes).
        """
        import numpy as np

        data = np.asarray(dataset)
        if data.ndim != 3:
            raise ValueError(
                'TrafficSplitter expects data shaped '
                '(timestep, num_node, channel), got shape {}'.format(
                    data.shape))

        # NOTE(review): ``client_num`` is presumably stored on ``self`` by
        # BaseSplitter.__init__ (not visible here) — confirm.
        client_num = self.client_num
        num_node = data.shape[1]
        if not 1 <= client_num <= num_node:
            raise ValueError(
                'client_num must be between 1 and num_node ({}), '
                'got {}'.format(num_node, client_num))

        # array_split (unlike np.split) accepts a node count that is not a
        # multiple of client_num, spreading the remainder over the first
        # pieces.
        return np.array_split(data, client_num, axis=1)
|
|
|
|
|
|
def call_my_splitter(splitter_type, client_num, **kwargs):
    """Return a ``TrafficSplitter`` for the ``'trafficflow'`` type.

    Args:
        splitter_type: name of the requested splitter.
        client_num: number of clients, forwarded to ``TrafficSplitter``.
        **kwargs: extra options, forwarded to ``TrafficSplitter``.

    Returns:
        A new ``TrafficSplitter`` when ``splitter_type == 'trafficflow'``,
        otherwise ``None``.
    """
    if splitter_type != 'trafficflow':
        # Not our type: hand back nothing.
        return None
    return TrafficSplitter(client_num, **kwargs)
|
|
|
|
|
|
# Register ``call_my_splitter`` with FederatedScope's splitter registry
# under the 'trafficflow' key.
register_splitter('trafficflow', call_my_splitter)
|