diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py index c2cb7df..e80fdad 100644 --- a/model/pytorch/dcrnn_cell.py +++ b/model/pytorch/dcrnn_cell.py @@ -67,6 +67,7 @@ class DCGRUCell(torch.nn.Module): supports.append(utils.calculate_scaled_laplacian(adj_mx)) for support in supports: self._supports.append(self._build_sparse_matrix(support)) + self._fc_params = LayerParams(self, 'fc') self._gconv_params = LayerParams(self, 'gconv') @@ -81,12 +82,11 @@ class DCGRUCell(torch.nn.Module): def forward(self, inputs, hx): """Gated recurrent unit (GRU) with Graph Convolution. - :param input: (B, num_nodes * input_dim) + :param inputs: (B, num_nodes * input_dim) + :param hx: (B, num_nodes * rnn_units) :return - - Output: A `2-D` tensor with shape `[batch_size x self.output_size]`. - - New state: Either a single `2-D` tensor, or a tuple of tensors matching - the arity and shapes of `state` + - Output: A `2-D` tensor with shape `(B, num_nodes * rnn_units)`. """ output_size = 2 * self._num_units if self._use_gc_for_ru: @@ -124,14 +124,6 @@ class DCGRUCell(torch.nn.Module): return value def _gconv(self, inputs, state, output_size, bias_start=0.0): - """Graph convolution between input and the graph matrix. - - :param args: a 2D Tensor or a list of 2D, batch x n, Tensors. - :param output_size: - :param bias: - :param bias_start: - :return: - """ # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim) batch_size = inputs.shape[0] inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))