FS-TFP/federatedscope/contrib/data/mini_graph_dt.py

165 lines
6.1 KiB
Python

import os
import numpy as np
from federatedscope.register import register_data
# Run with mini_graph_dt:
# python federatedscope/main.py --cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg.yaml --client_cfg \
# federatedscope/gfl/baseline/mini_graph_dc/fedavg_per_client.yaml
# Test Accuracy: ~0.7
def load_mini_graph_dt(config, client_cfgs=None):
    """Build the ``mini_graph_dt`` federated graph-classification data.

    Each dataset listed in ``MiniGraphDCDataset.DATA_NAME`` becomes one
    client. On first use every dataset is fetched through PyG
    (``MoleculeNet`` or ``TUDataset``), converted to a list of
    ``torch_geometric.data.Data`` objects, shuffled with a fixed NumPy seed
    and saved as ``train.pt`` / ``val.pt`` / ``test.pt`` under
    ``<config.data.root>/mini_graph_dt/processed/<client_index>/``.

    Args:
        config: Global config. ``config.data.root`` is the download/cache
            directory. ``federate.client_num`` is overwritten with the
            number of clients (one per dataset).
        client_cfgs: Optional per-client configs, forwarded unchanged to
            ``DummyDataTranslator``.

    Returns:
        tuple: ``(data, config)`` where ``data`` is the result of applying
        ``DummyDataTranslator`` to a dict mapping 1-based client ids to
        ``{'train': [...], 'val': [...], 'test': [...]}``.
    """
    # Imported inside the function, presumably so that importing this module
    # (for registration) does not require torch / PyG.
    import torch
    from torch_geometric.data import InMemoryDataset, Data
    from torch_geometric.datasets import TUDataset, MoleculeNet
    from federatedscope.core.splitters.graph.scaffold_lda_splitter import \
        GenFeatures
    from federatedscope.core.data import DummyDataTranslator

    class MiniGraphDCDataset(InMemoryDataset):
        """Dataset whose item ``idx`` is the split dict of client ``idx``."""
        NAME = 'mini_graph_dt'
        DATA_NAME = ['BACE', 'BBBP', 'CLINTOX', 'ENZYMES', 'PROTEINS_full']
        # Per-client cache of loaded splits. The class is re-created on each
        # call of load_mini_graph_dt, so the cache does not outlive a call.
        IN_MEMORY_DATA = {}

        def __init__(self, root, splits=None):
            self.root = root
            # Train / val / test fractions. A ``None`` default avoids a
            # mutable default list shared between calls.
            self.splits = [0.8, 0.1, 0.1] if splits is None else splits
            # Set before super().__init__: PyG's Dataset.__init__ is expected
            # to call process() when processed files are missing, and
            # process() reads self.splits.
            super(MiniGraphDCDataset, self).__init__(root)

        @property
        def processed_dir(self):
            return os.path.join(self.root, self.NAME, 'processed')

        @property
        def processed_file_names(self):
            # NOTE(review): these look like the marker files PyG writes
            # itself after process(); their presence presumably makes PyG
            # skip re-processing. Confirm against the installed PyG version.
            return ['pre_transform.pt', 'pre_filter.pt']

        def __len__(self):
            return len(self.DATA_NAME)

        def __getitem__(self, idx):
            """Return the cached ``{split: [Data, ...]}`` dict of client
            ``idx``; splits whose file is missing or empty are left out."""
            if idx not in self.IN_MEMORY_DATA:
                self.IN_MEMORY_DATA[idx] = {}
                for split in ['train', 'val', 'test']:
                    split_data = self._load(idx, split)
                    if split_data:
                        self.IN_MEMORY_DATA[idx][split] = split_data
            return self.IN_MEMORY_DATA[idx]

        def _load(self, idx, split):
            """Load one saved split; return ``None`` if it was never saved."""
            path = os.path.join(self.processed_dir, str(idx), f'{split}.pt')
            try:
                # The file is a pickle written by process(); only load
                # files from a trusted ``root``.
                return torch.load(path)
            except FileNotFoundError:
                # Only a missing file means "no such split". Other errors
                # (corrupt file, unpickling/version failures) propagate
                # instead of silently leaving the client without data.
                return None

        def process(self):
            """Convert every dataset and save its fixed train/val/test
            split under ``processed_dir/<idx>/``."""
            # Seed the global NumPy RNG so the splits are reproducible.
            np.random.seed(0)
            for idx, name in enumerate(self.DATA_NAME):
                if name in ['BACE', 'BBBP', 'CLINTOX']:
                    dataset = self._featurize_molecules(name)
                else:
                    # Classification
                    dataset = [
                        Data(edge_index=graph.edge_index,
                             x=graph.x,
                             y=graph.y)
                        for graph in TUDataset(self.root, name)
                    ]
                self._save_splits(idx, dataset)
                print(name, len(dataset), dataset[0])

        def _featurize_molecules(self, name):
            """Featurize a MoleculeNet dataset with GenFeatures and return
            a list of ``Data`` objects with integer (long) labels."""
            molecules = MoleculeNet(self.root, name)
            featurizer = GenFeatures()
            graphs = []
            for graph in molecules:
                graph = featurizer(graph)
                if name == 'CLINTOX':
                    # Index of the largest entry of the (flattened) label
                    # tensor, reshaped to (1, 1).
                    y = torch.argmax(graph.y).view(-1).unsqueeze(0)
                else:
                    y = graph.y.long()
                graphs.append(
                    Data(edge_index=graph.edge_index, x=graph.x, y=y))
            return graphs

        def _save_splits(self, idx, dataset):
            """Shuffle ``dataset`` with the global NumPy RNG and save its
            train/val/test parts for client ``idx``."""
            num = len(dataset)
            # We fix train/val/test
            index = np.random.permutation(np.arange(num))
            train_end = int(num * self.splits[0])
            valid_end = int(num * sum(self.splits[:2]))
            parts = {
                'train': index[:train_end],
                'val': index[train_end:valid_end],
                'test': index[valid_end:],
            }
            split_dir = os.path.join(self.processed_dir, str(idx))
            os.makedirs(split_dir, exist_ok=True)
            for split, split_idx in parts.items():
                torch.save([dataset[i] for i in split_idx],
                           os.path.join(split_dir, f'{split}.pt'))

        def meta_info(self):
            """Static per-dataset metadata (not used in this function)."""
            return {
                'BACE': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1513,
                },
                'BBBP': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 2039,
                },
                'CLINTOX': {
                    'task': 'classification',
                    'input_dim': 74,
                    'output_dim': 2,
                    'num_samples': 1478,
                },
                'ENZYMES': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 6,
                    'num_samples': 600,
                },
                'PROTEINS_full': {
                    'task': 'classification',
                    'input_dim': 3,
                    'output_dim': 2,
                    'num_samples': 1113,
                },
            }

    dataset = MiniGraphDCDataset(config.data.root)
    # Convert to dict; client ids start at 1.
    datadict = {
        client_id + 1: dataset[client_id]
        for client_id in range(len(dataset))
    }
    config.merge_from_list(['federate.client_num', len(dataset)])
    translator = DummyDataTranslator(config, client_cfgs)
    return translator(datadict), config
def call_mini_graph_dt(config, client_cfgs):
    """Data-loader hook for the ``"mini-graph-dc"`` data type.

    Returns ``(data, config)`` produced by ``load_mini_graph_dt`` when
    ``config.data.type`` equals ``"mini-graph-dc"``. For any other data type
    it returns ``None``, presumably signalling "not handled here" to the
    caller.
    """
    if config.data.type != "mini-graph-dc":
        return None
    data, new_config = load_mini_graph_dt(config, client_cfgs)
    return data, new_config
# Register call_mini_graph_dt with federatedscope.register.register_data
# under the key "mini-graph-dc" (the same string call_mini_graph_dt checks).
# NOTE(review): this module is named mini_graph_dt but registers the key
# "mini-graph-dc"; if another module registers the same key, confirm which
# registration takes effect.
register_data("mini-graph-dc", call_mini_graph_dt)