Implement model saving/loading and early stopping
This commit is contained in:
parent
d9f41172dc
commit
ba880b8230
|
|
@ -67,6 +67,24 @@ class DCRNNSupervisor:
|
|||
os.makedirs(log_dir)
|
||||
return log_dir
|
||||
|
||||
def save_model(self, epoch):
    """
    Write a checkpoint for ``epoch`` to ``<log_dir>models/epo<epoch>.tar``.

    The checkpoint is a copy of ``self._kwargs`` extended with the model's
    ``state_dict`` (key ``model_state_dict``) and the epoch number (key ``epoch``).

    :param epoch: epoch number, used in the checkpoint file name
    :return: path of the checkpoint file that was written
    """
    # Paths are built by plain concatenation on purpose: load_model resolves
    # checkpoints the same way, so the two must produce identical paths.
    model_dir = self._log_dir + 'models/'
    model_path = model_dir + 'epo%d.tar' % epoch

    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(model_dir, exist_ok=True)

    # Copy so the supervisor's own kwargs are not polluted with tensors.
    config = dict(self._kwargs)
    config['model_state_dict'] = self.dcrnn_model.state_dict()
    config['epoch'] = epoch
    torch.save(config, model_path)

    # Previously this logged "Loaded model", which was misleading in the logs.
    self._logger.info("Saved model at {}".format(epoch))
    return model_path
|
||||
|
||||
def load_model(self, epoch):
    """
    Restore the model weights from the checkpoint stored for ``epoch``.

    The checkpoint is read from ``<log_dir>models/epo<epoch>.tar`` onto the CPU
    and its ``model_state_dict`` entry is loaded into ``self.dcrnn_model``.

    :param epoch: epoch number identifying the checkpoint file
    :raises AssertionError: if no checkpoint file exists for that epoch
    """
    checkpoint_path = self._log_dir + 'models/epo%d.tar' % epoch
    found = os.path.exists(checkpoint_path)
    assert found, 'Weights at epoch %d not found' % epoch

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    saved = torch.load(checkpoint_path, map_location='cpu')
    weights = saved['model_state_dict']
    self.dcrnn_model.load_state_dict(weights)

    self._logger.info("Loaded model at {}".format(epoch))
|
||||
|
||||
def train(self, **kwargs):
    """
    Run training with the call-site keywords merged with ``self._train_kwargs``.

    Entries in ``self._train_kwargs`` take precedence over same-named
    keywords passed by the caller.

    :param kwargs: extra keyword arguments forwarded to ``self._train``
    :return: whatever ``self._train`` returns
    """
    merged_options = {**kwargs, **self._train_kwargs}
    return self._train(**merged_options)
|
||||
|
|
@ -97,12 +115,14 @@ class DCRNNSupervisor:
|
|||
min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
|
||||
test_every_n_epochs=10, **kwargs):
|
||||
# steps is used in learning rate - will see if need to use it?
|
||||
min_val_loss = float('inf')
|
||||
wait = 0
|
||||
batches_seen = 0
|
||||
optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
|
||||
criterion = torch.nn.L1Loss() # mae loss
|
||||
|
||||
self.dcrnn_model = self.dcrnn_model.train()
|
||||
|
||||
batches_seen = 0
|
||||
self._logger.info('Start training ...')
|
||||
for epoch_num in range(epochs):
|
||||
train_iterator = self._data['train_loader'].get_iterator()
|
||||
|
|
@ -145,6 +165,20 @@ class DCRNNSupervisor:
|
|||
np.mean(losses), test_loss, (end_time - start_time))
|
||||
self._logger.info(message)
|
||||
|
||||
if val_loss < min_val_loss:
|
||||
wait = 0
|
||||
min_val_loss = val_loss
|
||||
if save_model:
|
||||
model_file_name = self.save_model(epoch_num)
|
||||
self._logger.info(
|
||||
'Val loss decrease from {:.4f} to {:.4f}, saving to {}'.format(min_val_loss, val_loss,
|
||||
model_file_name))
|
||||
elif val_loss >= min_val_loss:
|
||||
wait += 1
|
||||
if wait == patience:
|
||||
self._logger.warning('Early stopping at epoch: %d' % epoch_num)
|
||||
break
|
||||
|
||||
def _get_x_y(self, x, y):
|
||||
"""
|
||||
:param x: shape (batch_size, seq_len, num_sensor, input_dim)
|
||||
|
|
|
|||
Loading…
Reference in New Issue