165 lines
6.1 KiB
Python
165 lines
6.1 KiB
Python
import os
|
|
import numpy as np
|
|
|
|
from federatedscope.register import register_data
|
|
|
|
# Run with mini_graph_dt:
# python federatedscope/main.py --cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml --client_cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
# Test Accuracy: ~0.7
|
|
|
|
|
|
def load_mini_graph_dt(config, client_cfgs=None):
    """Load the ``mini_graph_dt`` data, one client per source dataset.

    Five graph datasets (three from MoleculeNet, two from TUDataset) are
    processed into fixed train/val/test splits stored under
    ``<config.data.root>/mini_graph_dt/processed/<idx>/``. Each dataset is
    then assigned to one client, with client ids starting at 1.

    Args:
        config: Global configuration. ``config.data.root`` is used as the
            dataset root, and ``federate.client_num`` is overwritten with
            the number of source datasets.
        client_cfgs: Optional per-client configuration, forwarded unchanged
            to ``DummyDataTranslator``.

    Returns:
        A tuple ``(data, config)``: the translator's output for the
        ``{client_id: {split_name: graphs}}`` dict, and the modified config.
    """
    import logging

    import torch
    from torch_geometric.data import InMemoryDataset, Data
    from torch_geometric.datasets import TUDataset, MoleculeNet
    from federatedscope.core.splitters.graph.scaffold_lda_splitter import \
        GenFeatures
    from federatedscope.core.data import DummyDataTranslator

    logger = logging.getLogger(__name__)

    class MiniGraphDCDataset(InMemoryDataset):
        """Collection of small graph datasets, indexed by source dataset.

        ``dataset[idx]`` returns a dict mapping ``'train'``/``'val'``/
        ``'test'`` to the list of graphs saved for ``DATA_NAME[idx]``.
        """
        NAME = 'mini_graph_dt'
        DATA_NAME = ['BACE', 'BBBP', 'CLINTOX', 'ENZYMES', 'PROTEINS_full']
        # Cache of already loaded entries, keyed by dataset index.
        IN_MEMORY_DATA = {}

        def __init__(self, root, splits=(0.8, 0.1, 0.1)):
            """Store the root and split ratios, then run the base init.

            Args:
                root: Root directory for raw and processed data.
                splits: Train/val/test fractions. Only the first two are
                    read; the test split takes the remainder.
            """
            # A tuple default avoids sharing one mutable list across calls.
            self.root = root
            self.splits = splits
            super(MiniGraphDCDataset, self).__init__(root)

        @property
        def processed_dir(self):
            """Directory holding the per-dataset split files."""
            return os.path.join(self.root, self.NAME, 'processed')

        @property
        def processed_file_names(self):
            """Files whose presence marks the dataset as processed."""
            return ['pre_transform.pt', 'pre_filter.pt']

        def __len__(self):
            """Number of source datasets (one per client)."""
            return len(self.DATA_NAME)

        def __getitem__(self, idx):
            """Return the ``{split: graphs}`` dict for dataset ``idx``.

            Splits that are missing or empty are left out of the dict. The
            result is cached in ``IN_MEMORY_DATA`` after the first access.
            """
            if idx not in self.IN_MEMORY_DATA:
                self.IN_MEMORY_DATA[idx] = {}
                for split in ['train', 'val', 'test']:
                    split_data = self._load(idx, split)
                    if split_data:
                        self.IN_MEMORY_DATA[idx][split] = split_data
            return self.IN_MEMORY_DATA[idx]

        def _load(self, idx, split):
            """Load one saved split, or return ``None`` if unavailable."""
            path = os.path.join(self.processed_dir, str(idx), f'{split}.pt')
            try:
                # NOTE(review): newer PyTorch releases may default
                # ``torch.load`` to ``weights_only=True`` and reject these
                # pickled graph lists -- confirm against the pinned version.
                return torch.load(path)
            except FileNotFoundError:
                # A missing split file is the expected "no data" case.
                return None
            except Exception as error:
                # Stay best-effort, but do not hide a corrupt or unloadable
                # file, and let KeyboardInterrupt/SystemExit propagate.
                logger.warning('Failed to load %s: %s', path, error)
                return None

        def process(self):
            """Build every source dataset and save fixed data splits."""
            # A local generator seeded with 0 keeps the splits reproducible
            # without re-seeding the global NumPy RNG.
            rng = np.random.RandomState(0)
            for idx, name in enumerate(self.DATA_NAME):
                if name in ['BACE', 'BBBP', 'CLINTOX']:
                    featurizer = GenFeatures()
                    dataset = []
                    for graph in MoleculeNet(self.root, name):
                        graph = featurizer(graph)
                        dataset.append(
                            Data(edge_index=graph.edge_index,
                                 x=graph.x,
                                 y=graph.y))
                    if name in ['BACE', 'BBBP']:
                        for graph in dataset:
                            graph.y = graph.y.long()
                    if name in ['CLINTOX']:
                        # Collapse the label vector to the index of its
                        # largest entry.
                        for graph in dataset:
                            graph.y = torch.argmax(
                                graph.y).view(-1).unsqueeze(0)
                else:
                    # Classification
                    dataset = [
                        Data(edge_index=graph.edge_index,
                             x=graph.x,
                             y=graph.y)
                        for graph in TUDataset(self.root, name)
                    ]

                # We fix train/val/test
                num_graphs = len(dataset)
                train_end = int(num_graphs * self.splits[0])
                valid_end = int(num_graphs * sum(self.splits[:2]))
                index = rng.permutation(np.arange(num_graphs))
                train_idx = index[:train_end]
                valid_idx = index[train_end:valid_end]
                test_idx = index[valid_end:]

                split_dir = os.path.join(self.processed_dir, str(idx))
                os.makedirs(split_dir, exist_ok=True)

                train_path = os.path.join(split_dir, 'train.pt')
                valid_path = os.path.join(split_dir, 'val.pt')
                test_path = os.path.join(split_dir, 'test.pt')

                torch.save([dataset[i] for i in train_idx], train_path)
                torch.save([dataset[i] for i in valid_idx], valid_path)
                torch.save([dataset[i] for i in test_idx], test_path)

                print(name, len(dataset), dataset[0])

        def meta_info(self):
            """Return static task/dimension/size info per source dataset."""
            return {
                'BACE': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1513,
                },
                'BBBP': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 2039,
                },
                'CLINTOX': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1478,
                },
                'ENZYMES': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 6,
                    'num_samples': 600,
                },
                'PROTEINS_full': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 2,
                    'num_samples': 1113,
                },
            }

    dataset = MiniGraphDCDataset(config.data.root)
    # Convert to dict; client ids start at 1.
    datadict = {
        client_id + 1: dataset[client_id]
        for client_id in range(len(dataset))
    }
    config.merge_from_list(['federate.client_num', len(dataset)])
    translator = DummyDataTranslator(config, client_cfgs)

    return translator(datadict), config
|
|
|
|
|
|
def call_mini_graph_dt(config, client_cfgs):
    """Data-registry hook for the ``"mini-graph-dc"`` data type.

    Args:
        config: Configuration whose ``data.type`` selects the loader.
        client_cfgs: Per-client configuration passed on to the loader.

    Returns:
        ``(data, modified_config)`` from ``load_mini_graph_dt`` when
        ``config.data.type`` is ``"mini-graph-dc"``; otherwise ``None``.
    """
    if config.data.type != "mini-graph-dc":
        # Not our data type: return nothing.
        return None

    loaded, updated_cfg = load_mini_graph_dt(config, client_cfgs)
    return loaded, updated_cfg
|
|
|
|
|
|
# Register the loader hook under the "mini-graph-dc" name at import time.
register_data("mini-graph-dc", call_mini_graph_dt)
|