diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py index 7550501..b790730 100644 --- a/model/pytorch/dcrnn_cell.py +++ b/model/pytorch/dcrnn_cell.py @@ -74,9 +74,10 @@ class DCGRUCell(torch.nn.Module): def _build_sparse_matrix(L): L = L.tocoo() indices = np.column_stack((L.row, L.col)) + # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L) + indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))] L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device) return L - # return torch.sparse.sparse_reorder(L) def forward(self, inputs, hx): """Gated recurrent unit (GRU) with Graph Convolution.