FS-TFP/federatedscope/trafficflow/splitters/trafficSplitter.py

33 lines
822 B
Python

from federatedscope.register import register_splitter
from federatedscope.core.splitters import BaseSplitter
class TrafficSplitter(BaseSplitter):
    """Split a traffic-flow array among federated clients by node.

    The node axis (axis 1) is cut into ``client_num`` contiguous ranges of
    near-equal size; every client keeps all time steps and all channels.
    """

    def __init__(self, client_num, **kwargs):
        """
        Args:
            client_num: number of clients the dataset is divided into.
            **kwargs: forwarded unchanged to ``BaseSplitter.__init__``.
        """
        super(TrafficSplitter, self).__init__(client_num, **kwargs)
        # Keep our own reference so ``__call__`` does not depend on how the
        # base class stores this value.
        self.client_num = client_num

    def __call__(self, dataset, *args, **kwargs):
        """Partition ``dataset`` along its node axis.

        TODO: consider partitioning by subgraph labels instead of
        contiguous node ranges.

        Args:
            dataset: array shaped (timestep, num_node, channel). Any object
                with a ``.shape`` of at least two dimensions that supports
                ``dataset[:, a:b]`` slicing is accepted.
            *args: unused; accepted for splitter-interface compatibility.
            **kwargs: unused; accepted for splitter-interface compatibility.

        Returns:
            A list of ``client_num`` arrays; the i-th is shaped
            (timestep, nodes_i, channel). The first ``num_node % client_num``
            clients receive one extra node. For NumPy arrays the pieces are
            views into ``dataset``, not copies.

        Raises:
            ValueError: if ``client_num`` is not positive, ``dataset`` has
                fewer than two dimensions, or there are fewer nodes than
                clients (which would leave some client with no data).
        """
        client_num = self.client_num
        if client_num < 1:
            raise ValueError(
                f"client_num must be a positive integer, got {client_num}")

        shape = getattr(dataset, 'shape', None)
        if shape is None or len(shape) < 2:
            raise ValueError(
                "dataset must be an array shaped (timestep, num_node, ...), "
                f"got shape {shape}")

        num_node = shape[1]
        if num_node < client_num:
            raise ValueError(
                f"cannot split {num_node} nodes among {client_num} clients")

        # Same distribution rule as numpy.array_split: the remainder nodes
        # go one each to the leading clients.
        base, extra = divmod(num_node, client_num)
        pieces = []
        start = 0
        for idx in range(client_num):
            stop = start + base + (1 if idx < extra else 0)
            pieces.append(dataset[:, start:stop])
            start = stop
        return pieces
def call_my_splitter(splitter_type, client_num, **kwargs):
    """Build a ``TrafficSplitter`` when ``splitter_type`` is ``'trafficflow'``.

    Args:
        splitter_type: name of the requested splitter.
        client_num: number of clients, passed to ``TrafficSplitter``.
        **kwargs: extra keyword arguments passed to ``TrafficSplitter``.

    Returns:
        A new ``TrafficSplitter`` for ``'trafficflow'``; ``None`` for any other
        type (presumably so other registered builders can claim it — confirm
        against the registry's lookup logic).
    """
    if splitter_type != 'trafficflow':
        return None
    return TrafficSplitter(client_num, **kwargs)
# Register call_my_splitter under the key 'trafficflow' with the splitter
# registry imported from federatedscope.register (presumably consulted when
# the config selects splitter type 'trafficflow' — confirm).
register_splitter('trafficflow', call_my_splitter)