54 lines
1.8 KiB
Python
54 lines
1.8 KiB
Python
import torch
|
|
from torch_geometric.utils import to_networkx
|
|
|
|
|
|
def index_to_mask(index, size, device='cpu'):
|
|
mask = torch.zeros(size, dtype=torch.bool, device=device)
|
|
mask[index] = 1
|
|
return mask
|
|
|
|
|
|
def random_planetoid_splits(data,
|
|
num_classes,
|
|
percls_trn=20,
|
|
val_lb=500,
|
|
Flag=0):
|
|
|
|
indices = []
|
|
for i in range(num_classes):
|
|
index = (data.y == i).nonzero().view(-1)
|
|
index = index[torch.randperm(index.size(0))]
|
|
indices.append(index)
|
|
|
|
train_index = torch.cat([i[:percls_trn] for i in indices], dim=0)
|
|
|
|
if Flag == 0:
|
|
rest_index = torch.cat([i[percls_trn:] for i in indices], dim=0)
|
|
rest_index = rest_index[torch.randperm(rest_index.size(0))]
|
|
|
|
data.train_mask = index_to_mask(train_index, size=data.num_nodes)
|
|
data.val_mask = index_to_mask(rest_index[:val_lb], size=data.num_nodes)
|
|
data.test_mask = index_to_mask(rest_index[val_lb:],
|
|
size=data.num_nodes)
|
|
else:
|
|
val_index = torch.cat(
|
|
[i[percls_trn:percls_trn + val_lb] for i in indices], dim=0)
|
|
rest_index = torch.cat([i[percls_trn + val_lb:] for i in indices],
|
|
dim=0)
|
|
rest_index = rest_index[torch.randperm(rest_index.size(0))]
|
|
|
|
data.train_mask = index_to_mask(train_index, size=data.num_nodes)
|
|
data.val_mask = index_to_mask(val_index, size=data.num_nodes)
|
|
data.test_mask = index_to_mask(rest_index, size=data.num_nodes)
|
|
return data
|
|
|
|
|
|
def get_maxDegree(graphs):
|
|
maxdegree = 0
|
|
for i, graph in enumerate(graphs):
|
|
g = to_networkx(graph, to_undirected=True)
|
|
gdegree = max(dict(g.degree).values())
|
|
if gdegree > maxdegree:
|
|
maxdegree = gdegree
|
|
return maxdegree
|