Ensured sparse mm for readability, logging sparsely as well
This commit is contained in:
parent
2e1836df40
commit
9454fd91a2
|
|
@ -133,7 +133,6 @@ class DCGRUCell(torch.nn.Module):
|
||||||
state = torch.reshape(state, (batch_size, self._num_nodes, -1))
|
state = torch.reshape(state, (batch_size, self._num_nodes, -1))
|
||||||
inputs_and_state = torch.cat([inputs, state], dim=2)
|
inputs_and_state = torch.cat([inputs, state], dim=2)
|
||||||
input_size = inputs_and_state.size(2)
|
input_size = inputs_and_state.size(2)
|
||||||
dtype = inputs.dtype
|
|
||||||
|
|
||||||
x = inputs_and_state
|
x = inputs_and_state
|
||||||
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
|
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
|
||||||
|
|
@ -148,7 +147,7 @@ class DCGRUCell(torch.nn.Module):
|
||||||
x = self._concat(x, x1)
|
x = self._concat(x, x1)
|
||||||
|
|
||||||
for k in range(2, self._max_diffusion_step + 1):
|
for k in range(2, self._max_diffusion_step + 1):
|
||||||
x2 = 2 * torch.mm(support, x1) - x0
|
x2 = 2 * torch.sparse.mm(support, x1) - x0
|
||||||
x = self._concat(x, x2)
|
x = self._concat(x, x2)
|
||||||
x1, x0 = x2, x1
|
x1, x0 = x2, x1
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -149,7 +149,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
|
||||||
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
|
:return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
|
||||||
"""
|
"""
|
||||||
encoder_hidden_state = self.encoder(inputs)
|
encoder_hidden_state = self.encoder(inputs)
|
||||||
self._logger.info("Encoder complete, starting decoder")
|
self._logger.debug("Encoder complete, starting decoder")
|
||||||
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
||||||
self._logger.info("Decoder complete")
|
self._logger.debug("Decoder complete")
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
||||||
|
|
@ -143,7 +143,9 @@ class DCRNNSupervisor:
|
||||||
|
|
||||||
output = self.dcrnn_model(x, y, batches_seen)
|
output = self.dcrnn_model(x, y, batches_seen)
|
||||||
loss = self._compute_loss(y, output, criterion)
|
loss = self._compute_loss(y, output, criterion)
|
||||||
self._logger.info(loss.item())
|
|
||||||
|
self._logger.debug(loss.item())
|
||||||
|
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
batches_seen += 1
|
batches_seen += 1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue