added logging statement

This commit is contained in:
Chintan Shah 2019-10-07 07:59:41 -04:00
parent 941675d6a7
commit de42a67391
1 changed files with 8 additions and 0 deletions

View File

@ -7,6 +7,10 @@ from model.pytorch.dcrnn_cell import DCGRUCell
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def count_parameters(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Seq2SeqAttrs:
def __init__(self, adj_mx, **model_kwargs):
self.adj_mx = adj_mx
@ -156,4 +160,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
self._logger.debug("Encoder complete, starting decoder")
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
self._logger.debug("Decoder complete")
if batches_seen == 0:
self._logger.info(
"Total trainable parameters {}".format(count_parameters(self))
)
return outputs