diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index c919c72..f9cf1a4 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -71,6 +71,27 @@ class DCRNNSupervisor: kwargs.update(self._train_kwargs) return self._train(**kwargs) + def evaluate(self, dataset='val'): + """ + Computes mean L1Loss + :return: mean L1Loss + """ + self.dcrnn_model = self.dcrnn_model.eval() + + val_iterator = self._data['{}_loader'.format(dataset)].get_iterator() + losses = [] + criterion = torch.nn.L1Loss() + + for _, (x, y) in enumerate(val_iterator): + x, y = self._get_x_y(x, y) + x, y = self._get_x_y_in_correct_dims(x, y) + + output = self.dcrnn_model(x) + loss = self._compute_loss(y, output, criterion) + losses.append(loss.item()) + + return np.mean(losses) + def _train(self, base_lr, steps, patience=50, epochs=100, min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1, @@ -79,6 +100,8 @@ class DCRNNSupervisor: optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr) criterion = torch.nn.L1Loss() # mae loss + self.dcrnn_model = self.dcrnn_model.train() + batches_seen = 0 self._logger.info('Start training ...') for epoch_num in range(epochs): @@ -106,12 +129,20 @@ class DCRNNSupervisor: optimizer.step() + val_loss = self.evaluate(dataset='val') end_time = time.time() if epoch_num % log_every == 0: message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f} ' \ - 'lr:{:.6f} {:.1f}s'.format(epoch_num, epochs, batches_seen, - np.mean(losses), 0.0, - 0.0, (end_time - start_time)) + '{:.1f}s'.format(epoch_num, epochs, batches_seen, + np.mean(losses), val_loss, + (end_time - start_time)) + self._logger.info(message) + + if epoch_num % test_every_n_epochs == 0: + test_loss = self.evaluate(dataset='test') + message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f} ' \ + '{:.1f}s'.format(epoch_num, epochs, batches_seen, + np.mean(losses), test_loss, (end_time - start_time)) self._logger.info(message) def _get_x_y(self, x, y):