Implemented lr annealing schedule
This commit is contained in:
parent
ba880b8230
commit
8d3b1d0d66
|
|
@ -119,6 +119,9 @@ class DCRNNSupervisor:
|
||||||
wait = 0
|
wait = 0
|
||||||
batches_seen = 0
|
batches_seen = 0
|
||||||
optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
|
optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
|
||||||
|
|
||||||
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
|
||||||
|
gamma=lr_decay_ratio)
|
||||||
criterion = torch.nn.L1Loss() # mae loss
|
criterion = torch.nn.L1Loss() # mae loss
|
||||||
|
|
||||||
self.dcrnn_model = self.dcrnn_model.train()
|
self.dcrnn_model = self.dcrnn_model.train()
|
||||||
|
|
@ -149,20 +152,23 @@ class DCRNNSupervisor:
|
||||||
|
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
|
lr_scheduler.step()
|
||||||
|
|
||||||
val_loss = self.evaluate(dataset='val')
|
val_loss = self.evaluate(dataset='val')
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
if epoch_num % log_every == 0:
|
if epoch_num % log_every == 0:
|
||||||
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f} ' \
|
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f}, lr: {:.6f}' \
|
||||||
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
||||||
np.mean(losses), val_loss,
|
np.mean(losses), val_loss, lr_scheduler.get_lr(),
|
||||||
(end_time - start_time))
|
(end_time - start_time))
|
||||||
self._logger.info(message)
|
self._logger.info(message)
|
||||||
|
|
||||||
if epoch_num % test_every_n_epochs == 0:
|
if epoch_num % test_every_n_epochs == 0:
|
||||||
test_loss = self.evaluate(dataset='test')
|
test_loss = self.evaluate(dataset='test')
|
||||||
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f} ' \
|
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f} ' \
|
||||||
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
||||||
np.mean(losses), test_loss, (end_time - start_time))
|
np.mean(losses), test_loss, lr_scheduler.get_lr(),
|
||||||
|
(end_time - start_time))
|
||||||
self._logger.info(message)
|
self._logger.info(message)
|
||||||
|
|
||||||
if val_loss < min_val_loss:
|
if val_loss < min_val_loss:
|
||||||
|
|
@ -171,8 +177,8 @@ class DCRNNSupervisor:
|
||||||
if save_model:
|
if save_model:
|
||||||
model_file_name = self.save_model(epoch_num)
|
model_file_name = self.save_model(epoch_num)
|
||||||
self._logger.info(
|
self._logger.info(
|
||||||
'Val loss decrease from {:.4f} to {:.4f}, saving to {}'.format(min_val_loss, val_loss,
|
'Val loss decrease from {:.4f} to {:.4f}, '
|
||||||
model_file_name))
|
'saving to {}'.format(min_val_loss, val_loss, model_file_name))
|
||||||
elif val_loss >= min_val_loss:
|
elif val_loss >= min_val_loss:
|
||||||
wait += 1
|
wait += 1
|
||||||
if wait == patience:
|
if wait == patience:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue