Implementing model save/load and early stopping

Chintan Shah 2019-10-04 17:25:03 -04:00
parent d9f41172dc
commit ba880b8230
1 changed file with 35 additions and 1 deletion


@@ -67,6 +67,24 @@ class DCRNNSupervisor:
            os.makedirs(log_dir)
        return log_dir
    def save_model(self, epoch):
        if not os.path.exists(self._log_dir + 'models/'):
            os.makedirs(self._log_dir + 'models/')

        config = dict(self._kwargs)
        config['model_state_dict'] = self.dcrnn_model.state_dict()
        config['epoch'] = epoch
        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
        self._logger.info("Saved model at epoch {}".format(epoch))
        return self._log_dir + 'models/epo%d.tar' % epoch

    def load_model(self, epoch):
        assert os.path.exists(
            self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
        checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
        self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
        self._logger.info("Loaded model at epoch {}".format(epoch))
    def train(self, **kwargs):
        kwargs.update(self._train_kwargs)
        return self._train(**kwargs)
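A note on the checkpoint format introduced above (a hedged usage sketch, not part of the commit): save_model serializes the supervisor's constructor kwargs together with the model weights and the epoch number into a single .tar file, so a checkpoint can be inspected or restored directly with torch.load / load_model. The path data/models/epo10.tar and the epoch value below are illustrative only.

    import torch

    # The file written by save_model is a plain dict: the supervisor kwargs plus
    # 'model_state_dict' and 'epoch'.
    checkpoint = torch.load('data/models/epo10.tar', map_location='cpu')
    print(checkpoint['epoch'])                    # epoch the weights were saved at
    print(list(checkpoint['model_state_dict']))   # names of the DCRNN parameter tensors

    # Restore those weights into an existing DCRNNSupervisor instance:
    # supervisor.load_model(epoch=10)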
@@ -97,12 +115,14 @@ class DCRNNSupervisor:
               min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
               test_every_n_epochs=10, **kwargs):
        # steps is used in learning rate - will see if need to use it?
        min_val_loss = float('inf')
        wait = 0

        batches_seen = 0
        optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
        criterion = torch.nn.L1Loss()  # mae loss

        self.dcrnn_model = self.dcrnn_model.train()
        self._logger.info('Start training ...')

        for epoch_num in range(epochs):
            train_iterator = self._data['train_loader'].get_iterator()
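The _train signature above still accepts steps, min_learning_rate and lr_decay_ratio even though this version of the loop keeps a fixed base_lr (see the "# steps is used in learning rate" note). A hedged sketch of one way those arguments could be wired up using PyTorch's built-in MultiStepLR; treating steps as a list of epoch milestones is an assumption, not something this commit implements.

    # Sketch only, not part of this commit. Assumes `steps` is a list of epoch
    # milestones such as [20, 30, 40, 50].
    optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                        gamma=lr_decay_ratio)

    for epoch_num in range(epochs):
        ...                      # training / validation exactly as in the loop above
        lr_scheduler.step()      # multiply the lr by lr_decay_ratio at each milestone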
@@ -145,6 +165,20 @@ class DCRNNSupervisor:
                                      np.mean(losses), test_loss, (end_time - start_time))
                self._logger.info(message)
            if val_loss < min_val_loss:
                wait = 0
                if save_model:
                    model_file_name = self.save_model(epoch_num)
                    self._logger.info(
                        'Val loss decrease from {:.4f} to {:.4f}, saving to {}'.format(
                            min_val_loss, val_loss, model_file_name))
                min_val_loss = val_loss

            elif val_loss >= min_val_loss:
                wait += 1
                if wait == patience:
                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                    break
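Because save_model is only called when the validation loss improves, the most recently written file under models/ already holds the best weights when early stopping fires. A hedged sketch of how a caller could exploit that; supervisor, best_epoch and train_and_validate_one_epoch() are illustrative stand-ins that the commit itself does not define.

    # Hypothetical follow-up, not part of the commit: track the best epoch while
    # training, then reload its checkpoint for final evaluation.
    min_val_loss, best_epoch = float('inf'), 0
    for epoch_num in range(epochs):
        val_loss = train_and_validate_one_epoch()    # stand-in for the loop body above
        if val_loss < min_val_loss:
            best_epoch = epoch_num                   # save_model(epoch_num) wrote this checkpoint
            min_val_loss = val_loss

    supervisor.load_model(best_epoch)                # reload the best weights for evaluation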
    def _get_x_y(self, x, y):
        """
        :param x: shape (batch_size, seq_len, num_sensor, input_dim)