Ensured sparse mm for readability, logging sparsely as well

This commit is contained in:
Chintan Shah 2019-10-06 13:44:55 -04:00
parent 2e1836df40
commit 9454fd91a2
3 changed files with 6 additions and 5 deletions

View File

@ -133,7 +133,6 @@ class DCGRUCell(torch.nn.Module):
state = torch.reshape(state, (batch_size, self._num_nodes, -1))
inputs_and_state = torch.cat([inputs, state], dim=2)
input_size = inputs_and_state.size(2)
dtype = inputs.dtype
x = inputs_and_state
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
@ -148,7 +147,7 @@ class DCGRUCell(torch.nn.Module):
x = self._concat(x, x1)
for k in range(2, self._max_diffusion_step + 1):
x2 = 2 * torch.mm(support, x1) - x0
x2 = 2 * torch.sparse.mm(support, x1) - x0
x = self._concat(x, x2)
x1, x0 = x2, x1

View File

@ -149,7 +149,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
"""
encoder_hidden_state = self.encoder(inputs)
self._logger.info("Encoder complete, starting decoder")
self._logger.debug("Encoder complete, starting decoder")
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
self._logger.info("Decoder complete")
self._logger.debug("Decoder complete")
return outputs

View File

@ -143,7 +143,9 @@ class DCRNNSupervisor:
output = self.dcrnn_model(x, y, batches_seen)
loss = self._compute_loss(y, output, criterion)
self._logger.info(loss.item())
self._logger.debug(loss.item())
losses.append(loss.item())
batches_seen += 1