from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

from tensorflow.contrib.rnn import RNNCell
from tensorflow.python.platform import tf_logging as logging

from lib import dcrnn_utils


class DCGRUCell(RNNCell):
    """Graph Convolution Gated Recurrent Unit cell.

    The hidden state has shape (batch_size, num_nodes * num_units); the gates
    and the candidate state are computed with diffusion (graph) convolutions
    instead of dense matrix multiplications.
    """

    def call(self, inputs, **kwargs):
        # Unused: the recurrence is implemented directly in __call__ below.
        pass

    def _compute_output_shape(self, input_shape):
        # Unused: the output dimension is exposed via the output_size property.
        pass

    def __init__(self, num_units, adj_mx, max_diffusion_step, num_nodes, input_size=None, num_proj=None,
                 activation=tf.nn.tanh, reuse=None, filter_type="laplacian"):
        """
        :param num_units: number of hidden units per node.
        :param adj_mx: adjacency matrix of the graph, shape (num_nodes, num_nodes).
        :param max_diffusion_step: maximum diffusion step K of the graph convolution.
        :param num_nodes: number of nodes in the graph.
        :param input_size: deprecated and ignored.
        :param num_proj: if not None, project the output to this many features per node.
        :param activation: activation function applied to the candidate state.
        :param reuse: whether to reuse variables in an existing variable scope.
        :param filter_type: "laplacian", "random_walk", "dual_random_walk".
        """
        super(DCGRUCell, self).__init__(_reuse=reuse)
        if input_size is not None:
            logging.warn("%s: The input_size parameter is deprecated.", self)
        self._activation = activation
        self._num_nodes = num_nodes
        self._num_proj = num_proj
        self._num_units = num_units
        self._max_diffusion_step = max_diffusion_step
        self._supports = []
        supports = []
        if filter_type == "laplacian":
            supports.append(dcrnn_utils.calculate_scaled_laplacian(adj_mx, lambda_max=None))
        elif filter_type == "random_walk":
            supports.append(dcrnn_utils.calculate_random_walk_matrix(adj_mx).T)
        elif filter_type == "dual_random_walk":
            supports.append(dcrnn_utils.calculate_random_walk_matrix(adj_mx).T)
            supports.append(dcrnn_utils.calculate_random_walk_matrix(adj_mx.T).T)
        else:
            supports.append(dcrnn_utils.calculate_scaled_laplacian(adj_mx))
        for support in supports:
            self._supports.append(self._build_sparse_matrix(support))

    @staticmethod
    def _build_sparse_matrix(L):
        # Convert a scipy sparse matrix to a tf.SparseTensor with canonically
        # ordered indices.
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        L = tf.SparseTensor(indices, L.data, L.shape)
        return tf.sparse_reorder(L)

    @property
    def state_size(self):
        return self._num_nodes * self._num_units

    @property
    def output_size(self):
        output_size = self._num_nodes * self._num_units
        if self._num_proj is not None:
            output_size = self._num_nodes * self._num_proj
        return output_size
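
    # Cell recurrence implemented in __call__ below, where gconv denotes the
    # diffusion convolution in _gconv and [ , ] concatenation along the
    # per-node feature axis:
    #     r, u = split(sigmoid(gconv([inputs, state], 2 * num_units)))   # reset / update gates
    #     c = activation(gconv([inputs, r * state], num_units))          # candidate state
    #     new_state = u * state + (1 - u) * c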

    def __call__(self, inputs, state, scope=None):
        """Gated recurrent unit (GRU) with Graph Convolution.

        :param inputs: (B, num_nodes * input_dim)
        :param state: (B, num_nodes * num_units)

        :return:
        - Output: A `2-D` tensor with shape `[batch_size, self.output_size]`.
        - New state: Either a single `2-D` tensor, or a tuple of tensors matching
            the arity and shapes of `state`.
        """
        with tf.variable_scope(scope or "dcgru_cell"):
            with tf.variable_scope("gates"):  # Reset gate and update gate.
                # We start with bias of 1.0 to not reset and not update.
                value = tf.nn.sigmoid(
                    self._gconv(inputs, state, 2 * self._num_units, bias_start=1.0, scope=scope))
                # r and u are the (already sigmoid-activated) reset and update gates.
                r, u = tf.split(value=value, num_or_size_splits=2, axis=1)
            with tf.variable_scope("candidate"):
                c = self._gconv(inputs, r * state, self._num_units, scope=scope)
                if self._activation is not None:
                    c = self._activation(c)
            output = new_state = u * state + (1 - u) * c
            if self._num_proj is not None:
                with tf.variable_scope("projection"):
                    w = tf.get_variable('w', shape=(self._num_units, self._num_proj))
                    batch_size = inputs.get_shape()[0].value
                    output = tf.reshape(new_state, shape=(-1, self._num_units))
                    output = tf.reshape(tf.matmul(output, w), shape=(batch_size, self.output_size))
        return output, new_state

    @staticmethod
    def _concat(x, x_):
        # Append x_ as a new leading slice of x: (k, ...) + (...) -> (k + 1, ...).
        x_ = tf.expand_dims(x_, 0)
        return tf.concat([x, x_], axis=0)

    def _gconv(self, inputs, state, output_size, bias_start=0.0, scope=None):
        """Graph convolution between the concatenated [inputs, state] and the graph supports.

        :param inputs: a 2D Tensor of shape (batch_size, num_nodes * input_dim).
        :param state: a 2D Tensor of shape (batch_size, num_nodes * num_units).
        :param output_size: number of output features per node.
        :param bias_start: initial value of the bias term.
        :param scope: unused; the current variable scope is reused.
        :return: a 2D Tensor of shape (batch_size, num_nodes * output_size).
        """
        # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
        batch_size = inputs.get_shape()[0].value
        inputs = tf.reshape(inputs, (batch_size, self._num_nodes, -1))
        state = tf.reshape(state, (batch_size, self._num_nodes, -1))
        inputs_and_state = tf.concat([inputs, state], axis=2)
        input_size = inputs_and_state.get_shape()[2].value
        dtype = inputs.dtype

        x = inputs_and_state
        x0 = tf.transpose(x, perm=[1, 2, 0])  # (num_nodes, total_arg_size, batch_size)
        x0 = tf.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = tf.expand_dims(x0, axis=0)

        scope = tf.get_variable_scope()
        with tf.variable_scope(scope):
            if self._max_diffusion_step == 0:
                # K = 0 keeps only the identity term T_0 = X.
                pass
            else:
                for support in self._supports:
                    x1 = tf.sparse_tensor_dense_matmul(support, x0)
                    x = self._concat(x, x1)

                    for k in range(2, self._max_diffusion_step + 1):
                        # Chebyshev-style recurrence: T_k = 2 S T_{k-1} - T_{k-2}.
                        x2 = 2 * tf.sparse_tensor_dense_matmul(support, x1) - x0
                        x = self._concat(x, x2)
                        x1, x0 = x2, x1

            num_matrices = len(self._supports) * self._max_diffusion_step + 1  # Adds for x itself.
            x = tf.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
            x = tf.transpose(x, perm=[3, 1, 2, 0])  # (batch_size, num_nodes, input_size, order)
            x = tf.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

            weights = tf.get_variable(
                'weights', [input_size * num_matrices, output_size], dtype=dtype,
                initializer=tf.contrib.layers.xavier_initializer())
            x = tf.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)

            biases = tf.get_variable(
                "biases", [output_size],
                dtype=dtype,
                initializer=tf.constant_initializer(bias_start, dtype=dtype))
            x = tf.nn.bias_add(x, biases)
        # Reshape back to 2D: (batch_size, num_nodes, output_size) -> (batch_size, num_nodes * output_size)
        return tf.reshape(x, [batch_size, self._num_nodes * output_size])
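

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). Assumes a
# TensorFlow 1.x graph-mode environment with tf.contrib available and that
# `lib.dcrnn_utils` accepts a dense NumPy adjacency matrix; the toy graph and
# all sizes below are made up for demonstration.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    num_nodes, input_dim, num_units, batch_size = 5, 2, 16, 8

    # Toy symmetric adjacency matrix; any (num_nodes, num_nodes) array works here.
    adj_mx = np.random.rand(num_nodes, num_nodes).astype(np.float32)
    adj_mx = (adj_mx + adj_mx.T) / 2.0

    cell = DCGRUCell(num_units=num_units, adj_mx=adj_mx, max_diffusion_step=2,
                     num_nodes=num_nodes, filter_type='dual_random_walk')

    # One recurrent step: inputs are flattened per node, the initial state is
    # the cell's zero state of shape (batch_size, num_nodes * num_units).
    inputs = tf.placeholder(tf.float32, shape=(batch_size, num_nodes * input_dim))
    state = cell.zero_state(batch_size, tf.float32)
    output, new_state = cell(inputs, state)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        out = sess.run(output, feed_dict={
            inputs: np.zeros((batch_size, num_nodes * input_dim), dtype=np.float32)})
        print(out.shape)  # (8, 80) == (batch_size, num_nodes * num_units)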