diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index cb16561..cb66455 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -45,6 +45,10 @@ class DCRNNSupervisor:
         self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
         self._logger.info("Model created")
 
+        self._epoch_num = self._train_kwargs.get('epoch', 0)
+        if self._epoch_num > 0:
+            self.load_model()
+
     @staticmethod
     def _get_log_dir(kwargs):
         log_dir = kwargs['train'].get('log_dir')
@@ -74,22 +78,21 @@ class DCRNNSupervisor:
         return log_dir
 
     def save_model(self, epoch):
-        if not os.path.exists(self._log_dir + 'models/'):
-            os.makedirs(self._log_dir + 'models/')
+        if not os.path.exists('models/'):
+            os.makedirs('models/')
 
         config = dict(self._kwargs)
         config['model_state_dict'] = self.dcrnn_model.state_dict()
         config['epoch'] = epoch
-        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
+        torch.save(config, 'models/epo%d.tar' % epoch)
         self._logger.info("Saved model at {}".format(epoch))
-        return self._log_dir + 'models/epo%d.tar' % epoch
+        return 'models/epo%d.tar' % epoch
 
-    def load_model(self, epoch):
-        assert os.path.exists(
-            self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
-        checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
+    def load_model(self):
+        assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num
+        checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu')
         self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
-        self._logger.info("Loaded model at {}".format(epoch))
+        self._logger.info("Loaded model at {}".format(self._epoch_num))
 
     def train(self, **kwargs):
         kwargs.update(self._train_kwargs)
@@ -125,7 +128,6 @@ class DCRNNSupervisor:
         # steps is used in learning rate - will see if need to use it?
         min_val_loss = float('inf')
         wait = 0
-        batches_seen = 0
 
         optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
@@ -134,8 +136,14 @@ class DCRNNSupervisor:
         self.dcrnn_model = self.dcrnn_model.train()
 
         self._logger.info('Start training ...')
-        self._logger.info("num_batches:{}".format(self._data['train_loader'].num_batch))
-        for epoch_num in range(epochs):
+
+        # this will fail if model is loaded with a changed batch_size
+        num_batches = self._data['train_loader'].num_batch
+        self._logger.info("num_batches:{}".format(num_batches))
+
+        batches_seen = num_batches * self._epoch_num
+
+        for epoch_num in range(self._epoch_num, epochs):
             train_iterator = self._data['train_loader'].get_iterator()
             losses = []
 
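
For reference, a minimal sketch of how resuming would be driven after this change, assuming the repo's usual YAML-config entry point; the config path, the `load_graph_data` wiring, and the epoch value below are illustrative, not part of the patch. The only contract the patch itself establishes is the `train.epoch` key read via `self._train_kwargs.get('epoch', 0)`:

```python
# Hypothetical resume sketch (not part of the diff). Assumes a YAML config
# whose 'train' section is extended with the 'epoch' key consumed by
# DCRNNSupervisor.__init__ above.
import yaml

from lib.utils import load_graph_data  # assumed helper from the surrounding repo
from model.pytorch.dcrnn_supervisor import DCRNNSupervisor

with open('data/model/dcrnn_la.yaml') as f:  # illustrative config path
    supervisor_config = yaml.safe_load(f)

# A value > 0 makes __init__ call load_model(), which restores
# 'models/epo%d.tar' % epoch; the default of 0 trains from scratch.
supervisor_config['train']['epoch'] = 90

_, _, adj_mx = load_graph_data(supervisor_config['data']['graph_pkl_filename'])
supervisor = DCRNNSupervisor(adj_mx=adj_mx, **supervisor_config)

# Training resumes at epoch 90; batches_seen restarts at num_batches * 90,
# keeping the learning-rate schedule consistent with an uninterrupted run.
supervisor.train()
```

Note that this patch also moves checkpoints from `self._log_dir + 'models/'` to a working-directory-relative `models/` folder, so the resuming process must be launched from the same directory that produced the original `epo%d.tar` files.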