added logging statement
This commit is contained in:
parent
941675d6a7
commit
de42a67391
|
|
@ -7,6 +7,10 @@ from model.pytorch.dcrnn_cell import DCGRUCell
|
|||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class Seq2SeqAttrs:
|
||||
def __init__(self, adj_mx, **model_kwargs):
|
||||
self.adj_mx = adj_mx
|
||||
|
|
@ -156,4 +160,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
|
|||
self._logger.debug("Encoder complete, starting decoder")
|
||||
outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
|
||||
self._logger.debug("Decoder complete")
|
||||
if batches_seen == 0:
|
||||
self._logger.info(
|
||||
"Total trainable parameters {}".format(count_parameters(self))
|
||||
)
|
||||
return outputs
|
||||
|
|
|
|||
Loading…
Reference in New Issue