improved saving and restoring of model
This commit is contained in:
parent
a5a1063160
commit
3d93008a3e
|
|
@ -45,6 +45,10 @@ class DCRNNSupervisor:
|
||||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||||
self._logger.info("Model created")
|
self._logger.info("Model created")
|
||||||
|
|
||||||
|
self._epoch_num = self._train_kwargs.get('epoch', 0)
|
||||||
|
if self._epoch_num > 0:
|
||||||
|
self.load_model()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_log_dir(kwargs):
|
def _get_log_dir(kwargs):
|
||||||
log_dir = kwargs['train'].get('log_dir')
|
log_dir = kwargs['train'].get('log_dir')
|
||||||
|
|
@ -74,22 +78,21 @@ class DCRNNSupervisor:
|
||||||
return log_dir
|
return log_dir
|
||||||
|
|
||||||
def save_model(self, epoch):
|
def save_model(self, epoch):
|
||||||
if not os.path.exists(self._log_dir + 'models/'):
|
if not os.path.exists('models/'):
|
||||||
os.makedirs(self._log_dir + 'models/')
|
os.makedirs('models/')
|
||||||
|
|
||||||
config = dict(self._kwargs)
|
config = dict(self._kwargs)
|
||||||
config['model_state_dict'] = self.dcrnn_model.state_dict()
|
config['model_state_dict'] = self.dcrnn_model.state_dict()
|
||||||
config['epoch'] = epoch
|
config['epoch'] = epoch
|
||||||
torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
|
torch.save(config, 'models/epo%d.tar' % epoch)
|
||||||
self._logger.info("Saved model at {}".format(epoch))
|
self._logger.info("Saved model at {}".format(epoch))
|
||||||
return self._log_dir + 'models/epo%d.tar' % epoch
|
return 'models/epo%d.tar' % epoch
|
||||||
|
|
||||||
def load_model(self, epoch):
|
def load_model(self):
|
||||||
assert os.path.exists(
|
assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num
|
||||||
self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
|
checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu')
|
||||||
checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
|
|
||||||
self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
|
self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
self._logger.info("Loaded model at {}".format(epoch))
|
self._logger.info("Loaded model at {}".format(self._epoch_num))
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
kwargs.update(self._train_kwargs)
|
kwargs.update(self._train_kwargs)
|
||||||
|
|
@ -125,7 +128,6 @@ class DCRNNSupervisor:
|
||||||
# steps is used in learning rate - will see if need to use it?
|
# steps is used in learning rate - will see if need to use it?
|
||||||
min_val_loss = float('inf')
|
min_val_loss = float('inf')
|
||||||
wait = 0
|
wait = 0
|
||||||
batches_seen = 0
|
|
||||||
optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
|
optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
|
||||||
|
|
||||||
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
|
||||||
|
|
@ -134,8 +136,14 @@ class DCRNNSupervisor:
|
||||||
self.dcrnn_model = self.dcrnn_model.train()
|
self.dcrnn_model = self.dcrnn_model.train()
|
||||||
|
|
||||||
self._logger.info('Start training ...')
|
self._logger.info('Start training ...')
|
||||||
self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
|
|
||||||
for epoch_num in range(epochs):
|
# this will fail if model is loaded with a changed batch_size
|
||||||
|
num_batches = self._data['train_loader'].num_batch
|
||||||
|
self._logger.info("num_batches:{}".format(num_batches))
|
||||||
|
|
||||||
|
batches_seen = num_batches * self._epoch_num
|
||||||
|
|
||||||
|
for epoch_num in range(self._epoch_num, epochs):
|
||||||
train_iterator = self._data['train_loader'].get_iterator()
|
train_iterator = self._data['train_loader'].get_iterator()
|
||||||
losses = []
|
losses = []
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue