135 lines
4.4 KiB
Python
135 lines
4.4 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
import numpy as np
|
|
|
|
from torch.nn import Parameter
|
|
from torch.nn import Linear
|
|
from torch_geometric.data import Data
|
|
from torch_geometric.nn.conv.gcn_conv import gcn_norm
|
|
from torch_geometric.nn import MessagePassing, APPNP
|
|
|
|
|
|
class GPR_prop(MessagePassing):
    '''
    Propagation class for GPR_GNN.

    Holds a learnable vector ``temp`` of ``K + 1`` coefficients and returns
    ``temp[0] * x + sum_k temp[k + 1] * P^(k + 1) x`` where ``P`` denotes one
    ``propagate`` step using the edge weights produced by ``gcn_norm``.

    source: https://github.com/jianhao2016/GPRGNN/blob/master/src/GNN_models.py

    Arguments:
        K (int): number of propagation steps; ``temp`` has ``K + 1`` entries.
        alpha: parameter of the initialisation scheme. For ``'SGC'`` it is
            used as an index into the coefficient vector, so it has to be
            an integer there.
        Init (str): init method in ['SGC', 'PPR', 'NPPR', 'Random', 'WS'].
        Gamma: initial coefficients, only read when ``Init == 'WS'``.
        bias: accepted for signature compatibility; not used by this class.
        **kwargs: forwarded to ``MessagePassing.__init__``.

    Raises:
        AssertionError: if ``Init`` is not one of the supported names.
        ValueError: if ``Init == 'WS'`` and ``Gamma`` is ``None``.
    '''
    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
        super(GPR_prop, self).__init__(aggr='add', **kwargs)
        self.K = K
        self.Init = Init
        self.alpha = alpha

        assert Init in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']

        # Keep a detached copy of the user supplied coefficients so that
        # reset_parameters() can restore them later. It is stored as a plain
        # tensor attribute, hence it is not part of the state dict.
        self._gamma_init = None
        if Init == 'WS':
            if Gamma is None:
                raise ValueError(
                    "Init='WS' requires Gamma (the initial coefficients)")
            self._gamma_init = torch.tensor(Gamma)

        self.temp = Parameter(self._initial_coefficients())

    def _initial_coefficients(self):
        '''Return a fresh tensor with the initial coefficients for Init.'''
        K, alpha = self.K, self.alpha

        if self.Init == 'SGC':
            # SGC-like, note that in this case, alpha has to be a integer.
            # It means where the peak at when initializing GPR weights.
            coeffs = 0.0 * np.ones(K + 1)
            coeffs[alpha] = 1.0
        elif self.Init == 'PPR':
            # PPR-like
            coeffs = alpha * (1 - alpha)**np.arange(K + 1)
            coeffs[-1] = (1 - alpha)**K
        elif self.Init == 'NPPR':
            # Negative PPR
            coeffs = (alpha)**np.arange(K + 1)
            coeffs = coeffs / np.sum(np.abs(coeffs))
        elif self.Init == 'Random':
            # Random, normalised to unit L1 norm
            bound = np.sqrt(3 / (K + 1))
            coeffs = np.random.uniform(-bound, bound, K + 1)
            coeffs = coeffs / np.sum(np.abs(coeffs))
        elif self.Init == 'WS':
            # Specify Gamma
            return self._gamma_init.clone()
        else:
            # Only reachable when assertions are disabled (python -O).
            raise ValueError('Unsupported Init: {!r}'.format(self.Init))

        return torch.tensor(coeffs)

    def reset_parameters(self):
        '''Restore ``temp`` to the initial values of the configured Init.

        For ``Init == 'Random'`` a new random vector is drawn.
        '''
        with torch.no_grad():
            # copy_ keeps the dtype and device of the existing parameter.
            self.temp.copy_(self._initial_coefficients())

    def forward(self, x, edge_index, edge_weight=None):
        '''Return the coefficient-weighted sum of the propagated features.

        Arguments:
            x: node features; ``x.size(0)`` is passed to ``gcn_norm`` as the
                number of nodes.
            edge_index: graph connectivity, handed to ``gcn_norm``.
            edge_weight: optional edge weights, handed to ``gcn_norm``.
        '''
        edge_index, norm = gcn_norm(edge_index,
                                    edge_weight,
                                    num_nodes=x.size(0),
                                    dtype=x.dtype)

        # Zero-step term first, then add one term per propagation step.
        hidden = x * (self.temp[0])
        for k in range(self.K):
            x = self.propagate(edge_index, x=x, norm=norm)
            gamma = self.temp[k + 1]
            hidden = hidden + gamma * x
        return hidden

    def message(self, x_j, norm):
        '''Scale each message by its edge weight.'''
        # norm holds one weight per edge; view(-1, 1) lets it broadcast over
        # the trailing dimension of x_j (assumed 2-D here, verify if not).
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
                                          self.temp)
|
|
|
|
|
|
class GPR_Net(torch.nn.Module):
    r"""GPR-GNN model from the "Adaptive Universal Generalized PageRank
    Graph Neural Network" paper, in ICLR'21

    A two layer MLP (``lin1`` -> ReLU -> ``lin2``) followed by a propagation
    module; the output is passed through ``log_softmax`` over dimension 1.

    Arguments:
        in_channels (int): dimension of input.
        out_channels (int): dimension of output.
        hidden (int): dimension of hidden units, default=64.
        K (int): power of GPR-GNN, default=10.
        dropout (float): dropout ratio, default=.0.
        ppnp (str): propagation method in ['PPNP', 'GPR_prop']
        alpha: passed to the propagation module, default=0.1.
        Init (str): init method in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
        Gamma: passed to ``GPR_prop``; not used when ``ppnp == 'PPNP'``.
        dprate (float): dropout ratio applied right before the propagation
            step, default=0.5. With ``0.0`` that dropout is skipped.

    Raises:
        ValueError: if ``ppnp`` is not one of the supported names.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 K=10,
                 dropout=.0,
                 ppnp='GPR_prop',
                 alpha=0.1,
                 Init='PPR',
                 Gamma=None,
                 dprate=0.5):
        super(GPR_Net, self).__init__()
        self.lin1 = Linear(in_channels, hidden)
        self.lin2 = Linear(hidden, out_channels)

        if ppnp == 'PPNP':
            self.prop1 = APPNP(K, alpha)
        elif ppnp == 'GPR_prop':
            self.prop1 = GPR_prop(K, alpha, Init, Gamma)
        else:
            # Fail at construction instead of leaving prop1 undefined and
            # hitting an AttributeError on the first forward pass.
            raise ValueError(
                "ppnp must be 'PPNP' or 'GPR_prop', got {!r}".format(ppnp))

        self.Init = Init
        self.dprate = dprate
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise both linear layers and the propagation module."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.prop1.reset_parameters()

    def forward(self, data):
        """Return log-probabilities for the given graph.

        Arguments:
            data: either a ``Data`` object (``x`` and ``edge_index`` are
                read from it) or a ``(x, edge_index)`` tuple / list.

        Raises:
            TypeError: if ``data`` is of any other type.
        """
        if isinstance(data, Data):
            x, edge_index = data.x, data.edge_index
        elif isinstance(data, (tuple, list)):
            x, edge_index = data
        else:
            raise TypeError('Unsupported data type!')

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        # Extra dropout on the MLP output, skipped when dprate is zero.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index)
        return F.log_softmax(x, dim=1)
|