diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index c92dad3..a44e30d 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -161,7 +161,7 @@ class DCRNNSupervisor: val_loss = self.evaluate(dataset='val') end_time = time.time() if epoch_num % log_every == 0: - message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \ + message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}, ' \ '{:.1f}s'.format(epoch_num, epochs, batches_seen, np.mean(losses), val_loss, lr_scheduler.get_lr()[0], (end_time - start_time)) @@ -177,12 +177,13 @@ class DCRNNSupervisor: if val_loss < min_val_loss: wait = 0 - min_val_loss = val_loss if save_model: model_file_name = self.save_model(epoch_num) self._logger.info( 'Val loss decrease from {:.4f} to {:.4f}, ' 'saving to {}'.format(min_val_loss, val_loss, model_file_name)) + min_val_loss = val_loss + elif val_loss >= min_val_loss: wait += 1 if wait == patience: