From 9454fd91a2991e599829854efd03617a12da5d44 Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Sun, 6 Oct 2019 13:44:55 -0400 Subject: [PATCH] Ensured sparse mm for readability, logging sparsely as well --- model/pytorch/dcrnn_cell.py | 3 +-- model/pytorch/dcrnn_model.py | 4 ++-- model/pytorch/dcrnn_supervisor.py | 4 +++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py index 1bbdf20..0779a72 100644 --- a/model/pytorch/dcrnn_cell.py +++ b/model/pytorch/dcrnn_cell.py @@ -133,7 +133,6 @@ class DCGRUCell(torch.nn.Module): state = torch.reshape(state, (batch_size, self._num_nodes, -1)) inputs_and_state = torch.cat([inputs, state], dim=2) input_size = inputs_and_state.size(2) - dtype = inputs.dtype x = inputs_and_state x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size) @@ -148,7 +147,7 @@ class DCGRUCell(torch.nn.Module): x = self._concat(x, x1) for k in range(2, self._max_diffusion_step + 1): - x2 = 2 * torch.mm(support, x1) - x0 + x2 = 2 * torch.sparse.mm(support, x1) - x0 x = self._concat(x, x2) x1, x0 = x2, x1 diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 36196a3..4a14334 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -149,7 +149,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs): :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) """ encoder_hidden_state = self.encoder(inputs) - self._logger.info("Encoder complete, starting decoder") + self._logger.debug("Encoder complete, starting decoder") outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen) - self._logger.info("Decoder complete") + self._logger.debug("Decoder complete") return outputs diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index cadcf24..67c8440 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -143,7 +143,9 @@ class DCRNNSupervisor: output = self.dcrnn_model(x, y, batches_seen) loss = self._compute_loss(y, output, criterion) - self._logger.info(loss.item()) + + self._logger.debug(loss.item()) + losses.append(loss.item()) batches_seen += 1