import torch
import torch.nn.functional as F
import numpy as np

from torch.nn import Parameter
from torch.nn import Linear
from torch_geometric.data import Data
from torch_geometric.nn.conv.gcn_conv import gcn_norm
from torch_geometric.nn import MessagePassing, APPNP

# Names accepted by ``GPR_prop`` for its ``Init`` argument.
_GPR_INIT_METHODS = ('SGC', 'PPR', 'NPPR', 'Random', 'WS')


class GPR_prop(MessagePassing):
    '''
    propagation class for GPR_GNN
    source: https://github.com/jianhao2016/GPRGNN/blob/master/src/GNN_models.py

    The layer keeps K + 1 learnable coefficients in ``self.temp`` and returns
    ``sum_k temp[k] * x_k`` where ``x_0`` is the input and ``x_{k+1}`` is
    obtained from ``x_k`` by one ``propagate`` step.

    Arguments:
        K (int): number of propagation steps.
        alpha (float or int): parameter of the initialisation scheme. For
            ``Init='SGC'`` it is used as an index into the coefficient
            vector and therefore has to be an integer.
        Init (str): one of ['SGC', 'PPR', 'NPPR', 'Random', 'WS'].
        Gamma: initial coefficients, required when ``Init='WS'``.
        bias (bool): unused; kept so the signature stays compatible.
        **kwargs: forwarded to ``MessagePassing.__init__``.

    Raises:
        ValueError: if ``Init`` is unknown, or ``Init='WS'`` without Gamma.
    '''
    def __init__(self, K, alpha, Init, Gamma=None, bias=True, **kwargs):
        # Explicit raises instead of ``assert`` so that the checks survive
        # ``python -O``.
        if Init not in _GPR_INIT_METHODS:
            raise ValueError('Init must be one of {}, got {!r}'.format(
                list(_GPR_INIT_METHODS), Init))
        if Init == 'WS' and Gamma is None:
            raise ValueError("Init='WS' requires Gamma to be specified")

        super(GPR_prop, self).__init__(aggr='add', **kwargs)
        self.K = K
        self.Init = Init
        self.alpha = alpha
        # Detached private copy of the user-supplied coefficients, so that
        # reset_parameters() can restore them for Init='WS'.
        self._gamma = None
        if Gamma is not None:
            self._gamma = torch.as_tensor(Gamma).detach().clone()

        self.temp = Parameter(self._initial_coefficients())

    def _initial_coefficients(self):
        '''Return a fresh tensor with the initial GPR weights for self.Init.'''
        K, alpha = self.K, self.alpha
        if self.Init == 'SGC':
            # SGC-like, note that in this case, alpha has to be a integer.
            # It means where the peak at when initializing GPR weights.
            coeffs = np.zeros(K + 1)
            coeffs[alpha] = 1.0
        elif self.Init == 'PPR':
            # PPR-like
            coeffs = alpha * (1 - alpha)**np.arange(K + 1)
            coeffs[-1] = (1 - alpha)**K
        elif self.Init == 'NPPR':
            # Negative PPR
            coeffs = (alpha)**np.arange(K + 1)
            coeffs = coeffs / np.sum(np.abs(coeffs))
        elif self.Init == 'Random':
            # Random, normalised so that the absolute values sum to one.
            bound = np.sqrt(3 / (K + 1))
            coeffs = np.random.uniform(-bound, bound, K + 1)
            coeffs = coeffs / np.sum(np.abs(coeffs))
        else:
            # 'WS': the coefficients specified through Gamma.
            return self._gamma.clone()
        return torch.tensor(coeffs)

    def reset_parameters(self):
        '''Re-initialise ``self.temp`` with the scheme selected by ``Init``.

        For ``Init='Random'`` a new random vector is drawn.
        '''
        fresh = self._initial_coefficients()
        with torch.no_grad():
            self.temp.copy_(
                fresh.to(device=self.temp.device, dtype=self.temp.dtype))

    def forward(self, x, edge_index, edge_weight=None):
        '''Return the weighted sum of ``x`` and its K propagated versions.

        Arguments:
            x: node feature tensor; ``x.size(0)`` is taken as the number of
                nodes.
            edge_index: graph connectivity, handed to ``gcn_norm``.
            edge_weight: optional edge weights, handed to ``gcn_norm``.
        '''
        edge_index, norm = gcn_norm(edge_index,
                                    edge_weight,
                                    num_nodes=x.size(0),
                                    dtype=x.dtype)

        hidden = x * (self.temp[0])
        for k in range(self.K):
            # x is overwritten on purpose: every step propagates the result
            # of the previous one.
            x = self.propagate(edge_index, x=x, norm=norm)
            gamma = self.temp[k + 1]
            hidden = hidden + gamma * x
        return hidden

    def message(self, x_j, norm):
        '''Scale each message ``x_j`` by its ``norm`` entry.'''
        return norm.view(-1, 1) * x_j

    def __repr__(self):
        return '{}(K={}, temp={})'.format(self.__class__.__name__, self.K,
                                          self.temp)


class GPR_Net(torch.nn.Module):
    r"""GPR-GNN model from the "Adaptive Universal
    Generalized PageRank Graph Neural Network" paper, in ICLR'21

    Arguments:
        in_channels (int): dimension of input.
        out_channels (int): dimension of output.
        hidden (int): dimension of hidden units, default=64.
        K (int): power of GPR-GNN, default=10.
        dropout (float): dropout ratio applied before each linear layer,
            default=.0.
        ppnp (str): propagation method in ['PPNP', 'GPR_prop']
        alpha (float): passed to the propagation layer, default=0.1.
        Init (str): init method in ['SGC', 'PPR', 'NPPR', 'Random', 'WS']
        Gamma: initial GPR coefficients, only used by 'GPR_prop' (required
            when Init='WS').
        dprate (float): dropout ratio applied right before the propagation
            layer, default=0.5.

    Raises:
        ValueError: if ``ppnp`` is not one of the supported methods.
    """
    def __init__(self,
                 in_channels,
                 out_channels,
                 hidden=64,
                 K=10,
                 dropout=.0,
                 ppnp='GPR_prop',
                 alpha=0.1,
                 Init='PPR',
                 Gamma=None,
                 dprate=0.5):
        super(GPR_Net, self).__init__()
        self.lin1 = Linear(in_channels, hidden)
        self.lin2 = Linear(hidden, out_channels)

        if ppnp == 'PPNP':
            self.prop1 = APPNP(K, alpha)
        elif ppnp == 'GPR_prop':
            self.prop1 = GPR_prop(K, alpha, Init, Gamma)
        else:
            # Fail here rather than with an AttributeError on the first
            # forward pass because prop1 was never created.
            raise ValueError(
                "ppnp must be 'PPNP' or 'GPR_prop', got {!r}".format(ppnp))

        self.Init = Init
        self.dprate = dprate
        self.dropout = dropout

    def reset_parameters(self):
        """Re-initialise the two linear layers and the propagation layer."""
        self.lin1.reset_parameters()
        self.lin2.reset_parameters()
        self.prop1.reset_parameters()

    def forward(self, data):
        """Compute per-node log-probabilities.

        Arguments:
            data: either a ``Data`` object (``data.x`` and
                ``data.edge_index`` are used) or an ``(x, edge_index)``
                tuple/list.

        Returns:
            ``log_softmax`` over dim 1 of the propagated ``lin2`` output.

        Raises:
            TypeError: if ``data`` is of neither supported type.
        """
        if isinstance(data, Data):
            x, edge_index = data.x, data.edge_index
        elif isinstance(data, (tuple, list)):
            x, edge_index = data
        else:
            raise TypeError('Unsupported data type!')

        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.relu(self.lin1(x))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.lin2(x)

        # The extra dropout in front of the propagation is skipped entirely
        # when its rate is zero.
        if self.dprate != 0.0:
            x = F.dropout(x, p=self.dprate, training=self.training)
        x = self.prop1(x, edge_index)
        return F.log_softmax(x, dim=1)