logging and refactor
This commit is contained in:
parent
e563e1bf37
commit
31acadedce
|
|
@ -104,8 +104,7 @@ class DCRNNSupervisor:
|
|||
criterion = torch.nn.L1Loss()
|
||||
|
||||
for _, (x, y) in enumerate(val_iterator):
|
||||
x, y = self._get_x_y(x, y)
|
||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||
x, y = self._prepare_data(x, y)
|
||||
|
||||
output = self.dcrnn_model(x)
|
||||
loss = self._compute_loss(y, output, criterion)
|
||||
|
|
@ -139,13 +138,12 @@ class DCRNNSupervisor:
|
|||
for _, (x, y) in enumerate(train_iterator):
|
||||
optimizer.zero_grad()
|
||||
|
||||
x, y = self._get_x_y(x, y)
|
||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||
x, y = self._prepare_data(x, y)
|
||||
|
||||
output = self.dcrnn_model(x, y, batches_seen)
|
||||
loss = self._compute_loss(y, output, criterion)
|
||||
|
||||
self._logger.debug(loss.item())
|
||||
self._logger.info(loss.item())
|
||||
|
||||
losses.append(loss.item())
|
||||
|
||||
|
|
@ -156,6 +154,7 @@ class DCRNNSupervisor:
|
|||
torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
|
||||
|
||||
optimizer.step()
|
||||
self._logger.info("finished one batch in {}".format(time.time() - start_time))
|
||||
|
||||
lr_scheduler.step()
|
||||
|
||||
|
|
@ -190,6 +189,11 @@ class DCRNNSupervisor:
|
|||
self._logger.warning('Early stopping at epoch: %d' % epoch_num)
|
||||
break
|
||||
|
||||
def _prepare_data(self, x, y):
|
||||
x, y = self._get_x_y(x, y)
|
||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||
return x.to(device), y.to(device)
|
||||
|
||||
def _get_x_y(self, x, y):
|
||||
"""
|
||||
:param x: shape (batch_size, seq_len, num_sensor, input_dim)
|
||||
|
|
@ -216,7 +220,7 @@ class DCRNNSupervisor:
|
|||
x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
|
||||
y = y[..., :self.output_dim].view(self.horizon, batch_size,
|
||||
self.num_nodes * self.output_dim)
|
||||
return x.to(device), y.to(device)
|
||||
return x, y
|
||||
|
||||
def _compute_loss(self, y_true, y_predicted, criterion):
|
||||
loss = 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue