Added docstrings

This commit is contained in:
Chintan Shah 2019-10-06 17:12:06 -04:00
parent ad8ac8ff2f
commit a8814d5d93
1 changed files with 4 additions and 12 deletions

View File

@ -67,6 +67,7 @@ class DCGRUCell(torch.nn.Module):
supports.append(utils.calculate_scaled_laplacian(adj_mx))
for support in supports:
self._supports.append(self._build_sparse_matrix(support))
self._fc_params = LayerParams(self, 'fc')
self._gconv_params = LayerParams(self, 'gconv')
@ -81,12 +82,11 @@ class DCGRUCell(torch.nn.Module):
def forward(self, inputs, hx):
"""Gated recurrent unit (GRU) with Graph Convolution.
:param input: (B, num_nodes * input_dim)
:param inputs: (B, num_nodes * input_dim)
:param hx: (B, num_nodes * rnn_units)
:return
- Output: A `2-D` tensor with shape `[batch_size x self.output_size]`.
- New state: Either a single `2-D` tensor, or a tuple of tensors matching
the arity and shapes of `state`
- Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`.
"""
output_size = 2 * self._num_units
if self._use_gc_for_ru:
@ -124,14 +124,6 @@ class DCGRUCell(torch.nn.Module):
return value
def _gconv(self, inputs, state, output_size, bias_start=0.0):
"""Graph convolution between input and the graph matrix.
:param args: a 2D Tensor or a list of 2D, batch x n, Tensors.
:param output_size:
:param bias:
:param bias_start:
:return:
"""
# Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
batch_size = inputs.shape[0]
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))