diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index 6a8462a..c92dad3 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -163,15 +163,15 @@ class DCRNNSupervisor: if epoch_num % log_every == 0: message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \ '{:.1f}s'.format(epoch_num, epochs, batches_seen, - np.mean(losses), val_loss, lr_scheduler.get_lr(), + np.mean(losses), val_loss, lr_scheduler.get_lr()[0], (end_time - start_time)) self._logger.info(message) if epoch_num % test_every_n_epochs == 0: test_loss = self.evaluate(dataset='test') - message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \ + message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \ '{:.1f}s'.format(epoch_num, epochs, batches_seen, - np.mean(losses), test_loss, lr_scheduler.get_lr(), + np.mean(losses), test_loss, lr_scheduler.get_lr()[0], (end_time - start_time)) self._logger.info(message)