diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index f9cf1a4..a2d61c0 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -67,6 +67,24 @@ class DCRNNSupervisor:
             os.makedirs(log_dir)
         return log_dir
 
+    def save_model(self, epoch):
+        if not os.path.exists(self._log_dir + 'models/'):
+            os.makedirs(self._log_dir + 'models/')
+
+        config = dict(self._kwargs)
+        config['model_state_dict'] = self.dcrnn_model.state_dict()
+        config['epoch'] = epoch
+        torch.save(config, self._log_dir + 'models/epo%d.tar' % epoch)
+        self._logger.info("Saved model at {}".format(epoch))
+        return self._log_dir + 'models/epo%d.tar' % epoch
+
+    def load_model(self, epoch):
+        assert os.path.exists(
+            self._log_dir + 'models/epo%d.tar' % epoch), 'Weights at epoch %d not found' % epoch
+        checkpoint = torch.load(self._log_dir + 'models/epo%d.tar' % epoch, map_location='cpu')
+        self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
+        self._logger.info("Loaded model at {}".format(epoch))
+
     def train(self, **kwargs):
         kwargs.update(self._train_kwargs)
         return self._train(**kwargs)
@@ -97,12 +115,14 @@ class DCRNNSupervisor:
               min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
               test_every_n_epochs=10, **kwargs):
         # steps is used in learning rate - will see if need to use it?
+        min_val_loss = float('inf')
+        wait = 0
+        batches_seen = 0
 
         optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
         criterion = torch.nn.L1Loss()  # mae loss
 
         self.dcrnn_model = self.dcrnn_model.train()
-        batches_seen = 0
         self._logger.info('Start training ...')
         for epoch_num in range(epochs):
             train_iterator = self._data['train_loader'].get_iterator()
@@ -145,6 +165,20 @@ class DCRNNSupervisor:
                                          np.mean(losses), test_loss, (end_time - start_time))
             self._logger.info(message)
 
+            if val_loss < min_val_loss:
+                wait = 0
+                if save_model:
+                    model_file_name = self.save_model(epoch_num)
+                    self._logger.info(
+                        'Val loss decrease from {:.4f} to {:.4f}, saving to {}'.format(min_val_loss, val_loss,
+                                                                                       model_file_name))
+                min_val_loss = val_loss
+            elif val_loss >= min_val_loss:
+                wait += 1
+                if wait == patience:
+                    self._logger.warning('Early stopping at epoch: %d' % epoch_num)
+                    break
+
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)
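
The new `save_model`/`load_model` pair stores the full training config alongside the weights, so a checkpoint is self-describing. Below is a minimal sketch of that round trip; the `nn.Linear` stand-in, the `log_dir` value, and the `{'base_lr': 0.01}` config are illustrative, not the supervisor's real state.

```python
import os
import torch
import torch.nn as nn

log_dir = 'logs/demo/'   # hypothetical log dir; the real supervisor derives its own
model = nn.Linear(4, 2)  # stand-in for self.dcrnn_model

def save_model(epoch, kwargs):
    # Mirror the diff: <log_dir>/models/epo<N>.tar holds config + weights + epoch.
    os.makedirs(log_dir + 'models/', exist_ok=True)
    config = dict(kwargs)  # training config travels with the weights
    config['model_state_dict'] = model.state_dict()
    config['epoch'] = epoch
    path = log_dir + 'models/epo%d.tar' % epoch
    torch.save(config, path)
    return path

def load_model(epoch):
    path = log_dir + 'models/epo%d.tar' % epoch
    assert os.path.exists(path), 'Weights at epoch %d not found' % epoch
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    return checkpoint

path = save_model(3, {'base_lr': 0.01})
ckpt = load_model(3)
print(path, ckpt['epoch'], ckpt['base_lr'])  # logs/demo/models/epo3.tar 3 0.01
```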
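On the training setup in the second hunk: `torch.nn.L1Loss` computes mean absolute error, which is why the diff comments it as "mae loss". A tiny sketch of one Adam/L1 step on a stand-in model; the learning rate, shapes, and random data here are arbitrary, not values from the repo.

```python
import torch

model = torch.nn.Linear(4, 2)  # stand-in model; shapes are arbitrary
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # illustrative lr, not base_lr
criterion = torch.nn.L1Loss()  # mean absolute error

x, y = torch.randn(8, 4), torch.randn(8, 2)  # random data for the sketch
optimizer.zero_grad()
loss = criterion(model(x), y)  # MAE between prediction and target
loss.backward()
optimizer.step()
print(loss.item())
```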
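The last hunk adds patience-based early stopping: `wait` counts consecutive epochs without a new best validation loss, and training stops once it reaches `patience`. Note that `min_val_loss` should be updated only after logging, so the message shows the old best next to the new value. A self-contained sketch of that logic; `val_losses` is a made-up sequence standing in for real validation results.

```python
min_val_loss = float('inf')
wait = 0
patience = 3  # illustrative; the supervisor defaults to patience=50

val_losses = [0.9, 0.8, 0.82, 0.81, 0.85, 0.79]  # hypothetical per-epoch losses
for epoch_num, val_loss in enumerate(val_losses):
    if val_loss < min_val_loss:
        wait = 0
        print('Val loss decrease from {:.4f} to {:.4f}'.format(min_val_loss, val_loss))
        min_val_loss = val_loss  # update only after logging the old best
    else:
        wait += 1
        if wait == patience:
            print('Early stopping at epoch: %d' % epoch_num)
            break  # stops at epoch 4 here; 0.79 is never reached
```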