added logging statement
This commit is contained in:
parent
941675d6a7
commit
de42a67391
|
|
@ -7,6 +7,10 @@ from model.pytorch.dcrnn_cell import DCGRUCell
|
||||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
class Seq2SeqAttrs:
|
class Seq2SeqAttrs:
|
||||||
def __init__(self, adj_mx, **model_kwargs):
|
def __init__(self, adj_mx, **model_kwargs):
|
||||||
self.adj_mx = adj_mx
|
self.adj_mx = adj_mx
|
||||||
|
|
@ -156,4 +160,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
|
||||||
self._logger.debug("Encoder complete, starting decoder")
|
self._logger.debug("Encoder complete, starting decoder")
|
||||||
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
||||||
self._logger.debug("Decoder complete")
|
self._logger.debug("Decoder complete")
|
||||||
|
if batches_seen == 0:
|
||||||
|
self._logger.info(
|
||||||
|
"Total trainable parameters {}".format(count_parameters(self))
|
||||||
|
)
|
||||||
return outputs
|
return outputs
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue