Implementing load and save models and early stopping
parent d9f41172dc · commit ba880b8230
@@ -67,6 +67,24 @@ class DCRNNSupervisor:
             os.makedirs(log_dir)
         return log_dir
 
+    def save_model(self, epoch):
+        if not os.path.exists(self._log_dir + 'models/'):
+            os.makedirs(self._log_dir + 'models/')
+
+        config = dict(self._kwargs)
+        config['model_state_dict'] = self.dcrnn_model.state_dict()
+        config['epoch'] = epoch
+        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
+        self._logger.info("Saved model at {}".format(epoch))
+        return self._log_dir + 'models/epo%d.tar' % epoch
+
+    def load_model(self, epoch):
+        assert os.path.exists(
+            self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
+        checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
+        self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
+        self._logger.info("Loaded model at {}".format(epoch))
+
     def train(self, **kwargs):
         kwargs.update(self._train_kwargs)
         return self._train(**kwargs)
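For reference, here is a minimal, self-contained sketch of the checkpoint pattern the two new methods implement. The toy `nn.Linear` model, the `log_dir` path, and the config values are stand-ins (not part of the commit); the point is that the whole kwargs config dict is serialized next to the `state_dict`, so one `.tar` file carries both hyperparameters and weights:

```python
import os
import torch
import torch.nn as nn

model = nn.Linear(4, 2)        # toy stand-in for self.dcrnn_model
log_dir = 'runs/demo/'         # assumed path; mirrors self._log_dir
os.makedirs(log_dir + 'models/', exist_ok=True)

# Save: bundle config-style metadata together with the weights,
# mirroring save_model() above.
config = {'seq_len': 12, 'batch_size': 64}   # stand-in for self._kwargs
config['model_state_dict'] = model.state_dict()
config['epoch'] = 3
torch.save(config, log_dir + 'models/epo%d.tar' % 3)

# Load: restore the weights from the same file, mirroring load_model().
checkpoint = torch.load(log_dir + 'models/epo%d.tar' % 3, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
print('restored epoch:', checkpoint['epoch'])
```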
@@ -97,12 +115,14 @@ class DCRNNSupervisor:
                min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
                test_every_n_epochs=10, **kwargs):
         # steps is used in learning rate - will see if need to use it?
+        min_val_loss = float('inf')
+        wait = 0
+        batches_seen = 0
         optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
         criterion = torch.nn.L1Loss()  # mae loss
 
         self.dcrnn_model = self.dcrnn_model.train()
 
-        batches_seen = 0
         self._logger.info('Start training ...')
         for epoch_num in range(epochs):
             train_iterator = self._data['train_loader'].get_iterator()
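A quick standalone check (not part of the commit) that `torch.nn.L1Loss` is indeed the MAE named in the `# mae loss` comment: with its default `reduction='mean'`, it averages absolute errors over all elements.

```python
import torch

criterion = torch.nn.L1Loss()  # mean absolute error by default
pred = torch.tensor([2.0, 4.0])
true = torch.tensor([3.0, 1.0])
print(criterion(pred, true))   # tensor(2.) == (|2 - 3| + |4 - 1|) / 2
```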
@@ -145,6 +165,20 @@ class DCRNNSupervisor:
                                       np.mean(losses), test_loss, (end_time - start_time))
             self._logger.info(message)
 
+            if val_loss < min_val_loss:
+                wait = 0
+                if save_model:
+                    model_file_name = self.save_model(epoch_num)
+                    self._logger.info(
+                        'Val loss decrease from {:.4f} to {:.4f}, saving to {}'.format(min_val_loss, val_loss,
+                                                                                       model_file_name))
+                min_val_loss = val_loss
+            elif val_loss >= min_val_loss:
+                wait += 1
+                if wait == patience:
+                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
+                    break
+
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)
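The hunk above adds patience-based early stopping: `wait` counts consecutive epochs without a new best validation loss, and training stops once it reaches `patience`. A self-contained illustration of the same pattern, with made-up loss values rather than DCRNN results:

```python
patience = 3
min_val_loss = float('inf')
wait = 0
for epoch_num, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.74]):
    if val_loss < min_val_loss:
        wait = 0                    # improvement: reset the patience counter
        min_val_loss = val_loss     # remember the new best (checkpoint here)
    else:
        wait += 1                   # no improvement this epoch
        if wait == patience:
            print('Early stopping at epoch: %d' % epoch_num)
            break
```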