From 31acadedce9a17fe97148b3dee783d48e989654f Mon Sep 17 00:00:00 2001
From: Chintan Shah
Date: Sun, 6 Oct 2019 14:29:28 -0400
Subject: [PATCH] logging and refactor

---
 model/pytorch/dcrnn_supervisor.py | 16 ++++++++++------
 1 file changed, 10 insertions(+), 6 deletions(-)

diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index 6f3cdd2..592454b 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -104,8 +104,7 @@ class DCRNNSupervisor:
             criterion = torch.nn.L1Loss()
 
             for _, (x, y) in enumerate(val_iterator):
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x)
                 loss = self._compute_loss(y, output, criterion)
@@ -139,13 +138,12 @@ class DCRNNSupervisor:
             for _, (x, y) in enumerate(train_iterator):
                 optimizer.zero_grad()
 
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x, y, batches_seen)
 
                 loss = self._compute_loss(y, output, criterion)
 
-                self._logger.debug(loss.item())
+                self._logger.info(loss.item())
 
                 losses.append(loss.item())
@@ -156,6 +154,7 @@ class DCRNNSupervisor:
                 torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
 
                 optimizer.step()
+                self._logger.info("finished one batch in {}".format(time.time() - start_time))
 
             lr_scheduler.step()
 
@@ -190,6 +189,11 @@ class DCRNNSupervisor:
                 self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                 break
 
+    def _prepare_data(self, x, y):
+        x, y = self._get_x_y(x, y)
+        x, y = self._get_x_y_in_correct_dims(x, y)
+        return x.to(device), y.to(device)
+
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)
@@ -216,7 +220,7 @@ class DCRNNSupervisor:
         x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
         y = y[..., :self.output_dim].view(self.horizon, batch_size, self.num_nodes * self.output_dim)
-        return x.to(device), y.to(device)
+        return x, y
 
     def _compute_loss(self, y_true, y_predicted, criterion):
         loss = 0