improved saving and restoring of model

Chintan Shah 2019-10-07 11:56:14 -04:00
parent a5a1063160
commit 3d93008a3e
1 changed file with 20 additions and 12 deletions


@@ -45,6 +45,10 @@ class DCRNNSupervisor:
         self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
         self._logger.info("Model created")

+        self._epoch_num = self._train_kwargs.get('epoch', 0)
+        if self._epoch_num > 0:
+            self.load_model()
+
     @staticmethod
     def _get_log_dir(kwargs):
         log_dir = kwargs['train'].get('log_dir')
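With this change the supervisor can resume from a checkpoint at construction time: a nonzero epoch entry in the train section of the kwargs triggers load_model() before training starts. A minimal usage sketch, assuming the config is a plain dict passed as keyword arguments; only the 'train' -> 'epoch' key comes from the diff, the surrounding layout is illustrative:

supervisor_config = {
    'train': {
        'epoch': 42,  # nonzero, so __init__ will call load_model()
        # ... other training options (base_lr, epochs, ...) omitted
    },
    # ... 'data' and 'model' sections omitted
}

# epoch defaults to 0, so existing configs still train from scratch.
supervisor = DCRNNSupervisor(**supervisor_config)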
@@ -74,22 +78,21 @@ class DCRNNSupervisor:
         return log_dir

     def save_model(self, epoch):
-        if not os.path.exists(self._log_dir + 'models/'):
-            os.makedirs(self._log_dir + 'models/')
+        if not os.path.exists('models/'):
+            os.makedirs('models/')

         config = dict(self._kwargs)
         config['model_state_dict'] = self.dcrnn_model.state_dict()
         config['epoch'] = epoch
-        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
+        torch.save(config, 'models/epo%d.tar' % epoch)
         self._logger.info("Saved model at {}".format(epoch))
-        return self._log_dir + 'models/epo%d.tar' % epoch
+        return 'models/epo%d.tar' % epoch

-    def load_model(self, epoch):
-        assert os.path.exists(
-            self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
-        checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
+    def load_model(self):
+        assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num
+        checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu')
         self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
-        self._logger.info("Loaded model at {}".format(epoch))
+        self._logger.info("Loaded model at {}".format(self._epoch_num))

     def train(self, **kwargs):
         kwargs.update(self._train_kwargs)
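Note that save_model and load_model now use the relative models/ directory instead of self._log_dir + 'models/', so every run launched from the same working directory reads and writes one shared checkpoint folder. The checkpoint itself is just a torch.save'd dict holding the full run config plus model_state_dict and epoch, so a single .tar file carries everything needed to rebuild the run. A self-contained sketch of the same round trip, with a stand-in nn.Linear in place of the DCRNN model:

import os

import torch
import torch.nn as nn

model = nn.Linear(4, 2)  # stand-in for the DCRNN model
epoch = 3

# Save: config dict + weights + epoch number in one .tar file.
os.makedirs('models/', exist_ok=True)
config = {'model_state_dict': model.state_dict(), 'epoch': epoch}
torch.save(config, 'models/epo%d.tar' % epoch)

# Load: map_location='cpu' makes a GPU-trained checkpoint loadable anywhere.
checkpoint = torch.load('models/epo%d.tar' % epoch, map_location='cpu')
model.load_state_dict(checkpoint['model_state_dict'])
assert checkpoint['epoch'] == epoch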
@@ -125,7 +128,6 @@ class DCRNNSupervisor:
         # steps is used in learning rate - will see if need to use it?
         min_val_loss = float('inf')
         wait = 0
-        batches_seen = 0

         optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
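For context on the unchanged lines above: MultiStepLR multiplies the learning rate by a fixed factor each time the epoch counter passes one of the milestones. A small illustration with made-up values for base_lr, steps, and the decay factor (none of them are shown in this diff):

import torch

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=0.01)  # stands in for base_lr
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[20, 30, 40], gamma=0.1)  # steps, decay ratio

for epoch in range(50):
    optimizer.step()      # training batches would run here
    lr_scheduler.step()   # lr: 0.01 -> 0.001 -> 1e-4 -> 1e-5 at the milestones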
@@ -134,8 +136,14 @@
         self.dcrnn_model = self.dcrnn_model.train()

         self._logger.info('Start training ...')
-        self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
-        for epoch_num in range(epochs):
+        # this will fail if model is loaded with a changed batch_size
+        num_batches = self._data['train_loader'].num_batch
+        self._logger.info("num_batches:{}".format(num_batches))
+
+        batches_seen = num_batches * self._epoch_num
+
+        for epoch_num in range(self._epoch_num, epochs):
+
             train_iterator = self._data['train_loader'].get_iterator()
             losses = []
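This hunk carries the resume bookkeeping: batches_seen is reconstructed as num_batches * self._epoch_num rather than reset to 0, and the epoch loop starts at the loaded epoch, so anything keyed on the global batch counter (for example a scheduled-sampling decay, if the model uses one) continues where it left off. As the in-diff comment warns, this only holds while batch_size is unchanged, because num_batches depends on it. A toy sketch of the arithmetic with illustrative numbers:

num_batches = 300    # batches per epoch reported by the train loader
loaded_epoch = 42    # self._epoch_num restored from the checkpoint
epochs = 100

batches_seen = num_batches * loaded_epoch  # 12600, as if training never stopped

for epoch_num in range(loaded_epoch, epochs):
    for _ in range(num_batches):
        batches_seen += 1  # continues the pre-restart count exactly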