FS-TFP/federatedscope/gfl/dataset/utils.py

54 lines
1.8 KiB
Python

import torch
from torch_geometric.utils import to_networkx
def index_to_mask(index, size, device='cpu'):
    """Return a boolean mask that is ``True`` at the given positions.

    Args:
        index: Positions to mark; used directly to index the mask tensor.
        size: Shape argument forwarded to ``torch.zeros``.
        device: Device on which the mask is allocated (``'cpu'`` by default).

    Returns:
        A ``torch.bool`` tensor that is ``False`` everywhere except at
        ``index``.
    """
    selected = torch.zeros(size, dtype=torch.bool, device=device)
    selected[index] = True
    return selected
def random_planetoid_splits(data,
                            num_classes,
                            percls_trn=20,
                            val_lb=500,
                            Flag=0):
    """Attach random train/val/test node masks to ``data`` and return it.

    For each class ``0 .. num_classes - 1`` the nodes carrying that label are
    shuffled and the first ``percls_trn`` of them become training nodes. The
    remaining nodes are divided according to ``Flag``:

    * ``Flag == 0``: the leftovers of all classes are pooled and shuffled;
      the first ``val_lb`` form the validation set and the rest the test set.
    * otherwise: the next ``val_lb`` nodes *of every class* form the
      validation set and whatever is left over forms the test set.

    Args:
        data: Object exposing ``y`` (a tensor of node labels) and
            ``num_nodes``. It is modified in place: ``train_mask``,
            ``val_mask`` and ``test_mask`` are (re)assigned.
        num_classes: Number of label values to iterate over.
        percls_trn: Training nodes taken per class.
        val_lb: Validation size; a global count when ``Flag == 0``, a
            per-class count otherwise.
        Flag: Selects the splitting scheme described above.

    Returns:
        The same ``data`` object, with the three masks set.
    """
    # Allocate the masks on the device that holds the labels. The helper
    # defaults to CPU, and a CPU mask cannot be indexed with indices that
    # live on a different device.
    device = data.y.device

    # One independently shuffled index vector per class.
    indices = []
    for i in range(num_classes):
        index = (data.y == i).nonzero().view(-1)
        index = index[torch.randperm(index.size(0))]
        indices.append(index)

    train_index = torch.cat([i[:percls_trn] for i in indices], dim=0)

    if Flag == 0:
        rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0)
        rest_index = rest_index[torch.randperm(rest_index.size(0))]
        val_index = rest_index[:val_lb]
        test_index = rest_index[val_lb:]
    else:
        val_index = torch.cat(
            [i[percls_trn:percls_trn + val_lb] for i in indices], dim=0)
        test_index = torch.cat([i[percls_trn + val_lb:] for i in indices],
                               dim=0)
        # The order does not affect the resulting mask, but the shuffle is
        # kept so the global RNG advances exactly as it always has.
        test_index = test_index[torch.randperm(test_index.size(0))]

    data.train_mask = index_to_mask(train_index,
                                    size=data.num_nodes,
                                    device=device)
    data.val_mask = index_to_mask(val_index,
                                  size=data.num_nodes,
                                  device=device)
    data.test_mask = index_to_mask(test_index,
                                   size=data.num_nodes,
                                   device=device)
    return data
def get_maxDegree(graphs):
    """Return the largest node degree found across ``graphs``.

    Each graph is converted with ``to_networkx(..., to_undirected=True)`` and
    the maximum of its node degrees is compared against the running maximum.

    Args:
        graphs: Iterable of graph objects accepted by ``to_networkx``.

    Returns:
        The maximum degree over all graphs; ``0`` when ``graphs`` is empty or
        when none of the graphs contains a node.
    """
    maxdegree = 0
    for graph in graphs:
        g = to_networkx(graph, to_undirected=True)
        # ``default=0`` keeps a graph without nodes from raising ValueError.
        gdegree = max((degree for _, degree in g.degree), default=0)
        if gdegree > maxdegree:
            maxdegree = gdegree
    return maxdegree