92 lines
3.5 KiB
Python
92 lines
3.5 KiB
Python
from federatedscope.core.data.utils import filter_dict
|
|
|
|
try:
|
|
import torch
|
|
from torch.utils.data import Dataset
|
|
except ImportError:
|
|
torch = None
|
|
Dataset = object
|
|
|
|
|
|
def get_dataloader(dataset, config, split='train'):
    """
    Instantiate a DataLoader via config.

    Args:
        dataset: dataset from which to load the data.
        config: configs providing ``backend`` and the ``dataloader`` node
            (``type``, ``batch_size``, ``shuffle``, etc.).
        split: current split (default: ``train``). For any other split,
            ``shuffle`` and ``drop_last`` are forced to ``False`` and
            ``sizes`` to ``[-1]``. With ``graphsaint-rw``, non-train
            splits use ``NeighborSampler`` instead of the random-walk
            sampler.

    Returns:
        ``None`` if ``config.backend`` is not ``torch``; ``dataset`` itself
        if ``dataloader.type`` is ``raw``; otherwise an instance of the
        specific ``DataLoader`` configured by config.

    Raises:
        ValueError: if ``config.dataloader.type`` is not one of the
            supported types listed below.

    Note:
      The key-value pairs of ``dataloader.type`` and ``DataLoader``:
        ========================  ===============================
        ``dataloader.type``       Source
        ========================  ===============================
        ``raw``                   No DataLoader
        ``base``                  ``torch.utils.data.DataLoader``
        ``trafficflow``           ``torch.utils.data.DataLoader``
        ``pyg``                   ``torch_geometric.loader.DataLoader``
        ``graphsaint-rw``             \
        ``torch_geometric.loader.GraphSAINTRandomWalkSampler``
        ``neighbor``              ``torch_geometric.loader.NeighborSampler``
        ``mf``                    ``federatedscope.mf.dataloader.MFDataLoader``
        ========================  ===============================
    """
    # DataLoader builder only support torch backend now.
    if config.backend != 'torch':
        return None

    loader_type = config.dataloader.type
    is_train = split == 'train'

    if loader_type == 'raw':
        # No DataLoader: hand the dataset back untouched.
        return dataset

    # Loader classes are imported inside their branch so that only the
    # package backing the requested type has to be importable.
    if loader_type in ('base', 'trafficflow'):
        from torch.utils.data import DataLoader
        loader_cls = DataLoader
    elif loader_type == 'pyg':
        from torch_geometric.loader import DataLoader as PyGDataLoader
        loader_cls = PyGDataLoader
    elif loader_type == 'graphsaint-rw':
        if is_train:
            from torch_geometric.loader import GraphSAINTRandomWalkSampler
            loader_cls = GraphSAINTRandomWalkSampler
        else:
            from torch_geometric.loader import NeighborSampler
            loader_cls = NeighborSampler
    elif loader_type == 'neighbor':
        from torch_geometric.loader import NeighborSampler
        loader_cls = NeighborSampler
    elif loader_type == 'mf':
        from federatedscope.mf.dataloader import MFDataLoader
        loader_cls = MFDataLoader
    else:
        # Report the key that was actually read (``dataloader.type``) so an
        # unknown type always surfaces as ValueError rather than failing on
        # a config node that may not exist.
        raise ValueError(f'dataloader.type {loader_type} not found!')

    # Work on a copy so the overrides below never leak into the config.
    raw_args = dict(config.dataloader)
    if not is_train:
        raw_args['shuffle'] = False
        raw_args['sizes'] = [-1]
        raw_args['drop_last'] = False
        # For evaluation in GFL
        if loader_type in ('graphsaint-rw', 'neighbor'):
            raw_args['batch_size'] = 4096
            dataset = dataset[0].edge_index
    elif loader_type == 'graphsaint-rw':
        # Raw graph
        dataset = dataset[0]
    elif loader_type == 'neighbor':
        # edge_index of raw graph
        dataset = dataset[0].edge_index

    # NOTE(review): filter_dict presumably drops the keys that
    # ``loader_cls.__init__`` does not accept -- confirm against its
    # definition in federatedscope.core.data.utils.
    filtered_args = filter_dict(loader_cls.__init__, raw_args)
    return loader_cls(dataset, **filtered_args)
|