60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
from torch_geometric.data import Data
|
|
from torch_geometric.data.batch import Batch
|
|
|
|
from torch.nn import GRU, Linear, ReLU, Sequential
|
|
from torch_geometric.nn import NNConv, Set2Set
|
|
|
|
|
|
class MPNNs2s(nn.Module):
    r"""MPNN from "Neural Message Passing for Quantum Chemistry" for
    regression and classification on graphs.

    Node features are embedded by a linear layer, refined by
    ``num_mp_steps`` rounds of edge-conditioned convolution (``NNConv``)
    followed by a GRU update, pooled per graph with ``Set2Set`` and mapped
    to ``out_channels`` by a two-layer MLP.

    Source: https://github.com/pyg-team/pytorch_geometric/blob/master/examples/qm9_nn_conv.py

    Arguments:
        in_channels (int): Size of the input node features.
        out_channels (int): Dimension of the output.
        num_nn (int): Number of edge features (input size of the edge
            network that produces the ``NNConv`` weights).
        hidden (int): Size of the hidden node representations.
            Defaults to 64.
        num_mp_steps (int): Number of message-passing / GRU iterations. The
            same ``NNConv`` and ``GRU`` weights are reused at every step.
            Defaults to 3.
        processing_steps (int): Number of ``Set2Set`` processing steps.
            Defaults to 3.
        s2s_num_layers (int): ``num_layers`` passed to ``Set2Set``.
            Defaults to 3.
    """

    def __init__(self, in_channels, out_channels, num_nn, hidden=64,
                 num_mp_steps=3, processing_steps=3, s2s_num_layers=3):
        super().__init__()
        self.num_mp_steps = num_mp_steps

        # Embed raw node features into the hidden space.
        self.lin0 = Linear(in_channels, hidden)

        # Edge network: produces hidden * hidden values per edge for NNConv.
        # Named edge_net so it does not shadow the module-level ``nn`` alias.
        edge_net = Sequential(Linear(num_nn, 16), ReLU(),
                              Linear(16, hidden * hidden))
        self.conv = NNConv(hidden, hidden, edge_net, aggr='add')
        self.gru = GRU(hidden, hidden)

        # Graph-level readout; its output width is 2 * hidden (see lin1).
        self.set2set = Set2Set(hidden, processing_steps=processing_steps,
                               num_layers=s2s_num_layers)
        self.lin1 = Linear(2 * hidden, hidden)
        self.lin2 = Linear(hidden, out_channels)

    @staticmethod
    def _unpack(data):
        """Split ``data`` into ``(x, edge_index, edge_attr, batch)``.

        Accepts a PyG ``Batch``/``Data`` object or a 4-element tuple/list.
        ``batch`` is ``None`` when the input does not provide one.

        Raises:
            TypeError: If ``data`` is none of the supported types.
        """
        if isinstance(data, (Batch, Data)):
            # A plain Data object may have no ``batch`` attribute at all.
            return (data.x, data.edge_index, data.edge_attr,
                    getattr(data, 'batch', None))
        if isinstance(data, (tuple, list)):
            x, edge_index, edge_attr, batch = data
            return x, edge_index, edge_attr, batch
        raise TypeError('Unsupported data type!')

    def forward(self, data):
        """Compute one output row per graph.

        Args:
            data: A PyG ``Batch``/``Data`` object carrying ``x``,
                ``edge_index``, ``edge_attr`` (and optionally ``batch``), or
                a tuple/list ``(x, edge_index, edge_attr, batch)``. When
                ``batch`` is missing or ``None``, all nodes are treated as a
                single graph (graph index 0).

        Returns:
            Tensor of shape ``(num_graphs, out_channels)``.

        Raises:
            TypeError: If ``data`` is none of the supported types.
        """
        x, edge_index, edge_attr, batch = self._unpack(data)
        if batch is None:
            # Single un-batched graph: assign every node to graph 0.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        # Cast inputs to the model's parameter dtype (float32 by default) so
        # the model also works after e.g. ``.double()``.
        dtype = self.lin0.weight.dtype

        # Compact the GRU weights into one contiguous buffer; presumably kept
        # here for replicated modules (e.g. nn.DataParallel) — confirm.
        self.gru.flatten_parameters()

        out = F.relu(self.lin0(x.to(dtype)))
        # GRU input is (seq_len=1, num_nodes, hidden): nodes act as the batch.
        h = out.unsqueeze(0)

        edge_attr = edge_attr.to(dtype)  # loop-invariant, cast once
        for _ in range(self.num_mp_steps):
            m = F.relu(self.conv(out, edge_index, edge_attr))
            out, h = self.gru(m.unsqueeze(0), h)
            out = out.squeeze(0)

        out = self.set2set(out, batch)
        out = F.relu(self.lin1(out))
        return self.lin2(out)
|