Ensured sparse mm for readability, logging sparsely as well
This commit is contained in:
parent
2e1836df40
commit
9454fd91a2
|
|
@ -133,7 +133,6 @@ class DCGRUCell(torch.nn.Module):
|
|||
state = torch.reshape(state, (batch_size, self._num_nodes, -1))
|
||||
inputs_and_state = torch.cat([inputs, state], dim=2)
|
||||
input_size = inputs_and_state.size(2)
|
||||
dtype = inputs.dtype
|
||||
|
||||
x = inputs_and_state
|
||||
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
|
||||
|
|
@ -148,7 +147,7 @@ class DCGRUCell(torch.nn.Module):
|
|||
x = self._concat(x, x1)
|
||||
|
||||
for k in range(2, self._max_diffusion_step + 1):
|
||||
x2 = 2 * torch.mm(support, x1) - x0
|
||||
x2 = 2 * torch.sparse.mm(support, x1) - x0
|
||||
x = self._concat(x, x2)
|
||||
x1, x0 = x2, x1
|
||||
|
||||
|
|
|
|||
|
|
@ -149,7 +149,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
|
|||
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
|
||||
"""
|
||||
encoder_hidden_state = self.encoder(inputs)
|
||||
self._logger.info("Encoder complete, starting decoder")
|
||||
self._logger.debug("Encoder complete, starting decoder")
|
||||
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
||||
self._logger.info("Decoder complete")
|
||||
self._logger.debug("Decoder complete")
|
||||
return outputs
|
||||
|
|
|
|||
|
|
@ -143,7 +143,9 @@ class DCRNNSupervisor:
|
|||
|
||||
output = self.dcrnn_model(x, y, batches_seen)
|
||||
loss = self._compute_loss(y, output, criterion)
|
||||
self._logger.info(loss.item())
|
||||
|
||||
self._logger.debug(loss.item())
|
||||
|
||||
losses.append(loss.item())
|
||||
|
||||
batches_seen += 1
|
||||
|
|
|
|||
Loading…
Reference in New Issue