diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py index 2a0b85f..7550501 100644 --- a/model/pytorch/dcrnn_cell.py +++ b/model/pytorch/dcrnn_cell.py @@ -74,7 +74,7 @@ class DCGRUCell(torch.nn.Module): def _build_sparse_matrix(L): L = L.tocoo() indices = np.column_stack((L.row, L.col)) - L = torch.sparse_coo_tensor(indices.T, L.data, L.shape) + L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device) return L # return torch.sparse.sparse_reorder(L)