logging and refactor

This commit is contained in:
Chintan Shah 2019-10-06 14:29:28 -04:00
parent e563e1bf37
commit 31acadedce
1 changed files with 10 additions and 6 deletions

View File

@ -104,8 +104,7 @@ class DCRNNSupervisor:
criterion = torch.nn.L1Loss() criterion = torch.nn.L1Loss()
for _, (x, y) in enumerate(val_iterator): for _, (x, y) in enumerate(val_iterator):
x, y = self._get_x_y(x, y) x, y = self._prepare_data(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
output = self.dcrnn_model(x) output = self.dcrnn_model(x)
loss = self._compute_loss(y, output, criterion) loss = self._compute_loss(y, output, criterion)
@ -139,13 +138,12 @@ class DCRNNSupervisor:
for _, (x, y) in enumerate(train_iterator): for _, (x, y) in enumerate(train_iterator):
optimizer.zero_grad() optimizer.zero_grad()
x, y = self._get_x_y(x, y) x, y = self._prepare_data(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
output = self.dcrnn_model(x, y, batches_seen) output = self.dcrnn_model(x, y, batches_seen)
loss = self._compute_loss(y, output, criterion) loss = self._compute_loss(y, output, criterion)
self._logger.debug(loss.item()) self._logger.info(loss.item())
losses.append(loss.item()) losses.append(loss.item())
@ -156,6 +154,7 @@ class DCRNNSupervisor:
torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm) torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
optimizer.step() optimizer.step()
self._logger.info("finished one batch in {}".format(time.time() - start_time))
lr_scheduler.step() lr_scheduler.step()
@ -190,6 +189,11 @@ class DCRNNSupervisor:
self._logger.warning('Early stopping at epoch: %d' % epoch_num) self._logger.warning('Early stopping at epoch: %d' % epoch_num)
break break
def _prepare_data(self, x, y):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
return x.to(device), y.to(device)
def _get_x_y(self, x, y): def _get_x_y(self, x, y):
""" """
:param x: shape (batch_size, seq_len, num_sensor, input_dim) :param x: shape (batch_size, seq_len, num_sensor, input_dim)
@ -216,7 +220,7 @@ class DCRNNSupervisor:
x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim) x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
y = y[..., :self.output_dim].view(self.horizon, batch_size, y = y[..., :self.output_dim].view(self.horizon, batch_size,
self.num_nodes * self.output_dim) self.num_nodes * self.output_dim)
return x.to(device), y.to(device) return x, y
def _compute_loss(self, y_true, y_predicted, criterion): def _compute_loss(self, y_true, y_predicted, criterion):
loss = 0 loss = 0