Added docstrings
This commit is contained in:
parent
ad8ac8ff2f
commit
a8814d5d93
|
|
@ -67,6 +67,7 @@ class DCGRUCell(torch.nn.Module):
|
|||
supports.append(utils.calculate_scaled_laplacian(adj_mx))
|
||||
for support in supports:
|
||||
self._supports.append(self._build_sparse_matrix(support))
|
||||
|
||||
self._fc_params = LayerParams(self, 'fc')
|
||||
self._gconv_params = LayerParams(self, 'gconv')
|
||||
|
||||
|
|
@ -81,12 +82,11 @@ class DCGRUCell(torch.nn.Module):
|
|||
|
||||
def forward(self, inputs, hx):
|
||||
"""Gated recurrent unit (GRU) with Graph Convolution.
|
||||
:param input: (B, num_nodes * input_dim)
|
||||
:param inputs: (B, num_nodes * input_dim)
|
||||
:param hx: (B, num_nodes * rnn_units)
|
||||
|
||||
:return
|
||||
- Output: A `2-D` tensor with shape `[batch_size x self.output_size]`.
|
||||
- New state: Either a single `2-D` tensor, or a tuple of tensors matching
|
||||
the arity and shapes of `state`
|
||||
- Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`.
|
||||
"""
|
||||
output_size = 2 * self._num_units
|
||||
if self._use_gc_for_ru:
|
||||
|
|
@ -124,14 +124,6 @@ class DCGRUCell(torch.nn.Module):
|
|||
return value
|
||||
|
||||
def _gconv(self, inputs, state, output_size, bias_start=0.0):
|
||||
"""Graph convolution between input and the graph matrix.
|
||||
|
||||
:param args: a 2D Tensor or a list of 2D, batch x n, Tensors.
|
||||
:param output_size:
|
||||
:param bias:
|
||||
:param bias_start:
|
||||
:return:
|
||||
"""
|
||||
# Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
|
||||
batch_size = inputs.shape[0]
|
||||
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
|
||||
|
|
|
|||
Loading…
Reference in New Issue