From de42a67391974aa0af25b9dbccfade480f36082e Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Mon, 7 Oct 2019 07:59:41 -0400 Subject: [PATCH] added logging statement --- model/pytorch/dcrnn_model.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 9390c6e..560f7ca 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -7,6 +7,10 @@ from model.pytorch.dcrnn_cell import DCGRUCell device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + class Seq2SeqAttrs: def __init__(self, adj_mx, **model_kwargs): self.adj_mx = adj_mx @@ -156,4 +160,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs): self._logger.debug("Encoder complete, starting decoder") outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen) self._logger.debug("Decoder complete") + if batches_seen == 0: + self._logger.info( + "Total trainable parameters {}".format(count_parameters(self)) + ) return outputs