logging and refactor
parent e563e1bf37
commit 31acadedce
@@ -104,8 +104,7 @@ class DCRNNSupervisor:
             criterion = torch.nn.L1Loss()
 
             for _, (x, y) in enumerate(val_iterator):
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x)
                 loss = self._compute_loss(y, output, criterion)
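
For context, a minimal standalone sketch of the evaluation pattern this hunk follows (eval mode under torch.no_grad(), per-batch L1 loss collected from a loader); the toy model and random batches are placeholders, not the repository's DCRNN model or data loader.

import torch

model = torch.nn.Linear(4, 1)                 # stand-in for self.dcrnn_model
criterion = torch.nn.L1Loss()
val_iterator = [(torch.randn(8, 4), torch.randn(8, 1)) for _ in range(3)]  # fake loader

model.eval()                                  # disable dropout/batch-norm updates
losses = []
with torch.no_grad():                         # no gradients during validation
    for _, (x, y) in enumerate(val_iterator):
        output = model(x)
        losses.append(criterion(output, y).item())

print(sum(losses) / len(losses))              # mean validation L1 loss
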
@@ -139,13 +138,12 @@ class DCRNNSupervisor:
             for _, (x, y) in enumerate(train_iterator):
                 optimizer.zero_grad()
 
-                x, y = self._get_x_y(x, y)
-                x, y = self._get_x_y_in_correct_dims(x, y)
+                x, y = self._prepare_data(x, y)
 
                 output = self.dcrnn_model(x, y, batches_seen)
                 loss = self._compute_loss(y, output, criterion)
 
-                self._logger.debug(loss.item())
+                self._logger.info(loss.item())
 
                 losses.append(loss.item())
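
Besides the _prepare_data call, the only change in this hunk is the log level of the per-batch loss (debug to info). A small sketch of what that level change means with Python's logging module; the logger name and level configuration here are assumptions for illustration, not the supervisor's actual setup.

import logging

logging.basicConfig(level=logging.INFO)       # handler threshold: INFO and above
logger = logging.getLogger("dcrnn_example")   # hypothetical logger name

loss_item = 0.42                              # placeholder for loss.item()
logger.debug(loss_item)                       # dropped: below the INFO threshold
logger.info(loss_item)                        # emitted: what this change enables
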
@@ -156,6 +154,7 @@ class DCRNNSupervisor:
                 torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
 
                 optimizer.step()
+                self._logger.info("finished one batch in {}".format(time.time() - start_time))
 
             lr_scheduler.step()
 
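
A sketch of the per-batch timing pattern behind the added log line, combined with gradient clipping before the optimizer step; the toy model, data, and max_grad_norm value are assumptions, not taken from this repository.

import time
import torch

model = torch.nn.Linear(4, 1)                                # toy stand-in for the DCRNN model
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
max_grad_norm = 5.0                                          # assumed value; the repo reads it from config

x, y = torch.randn(8, 4), torch.randn(8, 1)

start_time = time.time()
optimizer.zero_grad()
loss = torch.nn.functional.l1_loss(model(x), y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)  # cap the gradient norm
optimizer.step()
print("finished one batch in {}".format(time.time() - start_time))
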
@@ -190,6 +189,11 @@ class DCRNNSupervisor:
                     self._logger.warning('Early stopping at epoch: %d' % epoch_num)
                     break
 
+    def _prepare_data(self, x, y):
+        x, y = self._get_x_y(x, y)
+        x, y = self._get_x_y_in_correct_dims(x, y)
+        return x.to(device), y.to(device)
+
     def _get_x_y(self, x, y):
         """
         :param x: shape (batch_size, seq_len, num_sensor, input_dim)
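
An illustrative sketch of the reshaping that the new _prepare_data helper wraps: from (batch_size, seq_len, num_sensor, input_dim) to (seq_len, batch_size, num_nodes * input_dim), followed by a single move to the target device. The permute step and the concrete sizes are assumptions for the example; only the target shape and the .to(device) call appear in this diff.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch_size, seq_len, num_nodes, input_dim = 2, 12, 207, 2   # illustrative sizes only

x = torch.randn(batch_size, seq_len, num_nodes, input_dim)
x = x.permute(1, 0, 2, 3)                                   # assumed batch/time swap
x = x.reshape(seq_len, batch_size, num_nodes * input_dim)   # flatten node and feature dims
x = x.to(device)                                            # single device move, as in _prepare_data
print(x.shape)                                              # torch.Size([12, 2, 414])
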
@@ -216,7 +220,7 @@ class DCRNNSupervisor:
         x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
         y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                           self.num_nodes * self.output_dim)
-        return x.to(device), y.to(device)
+        return x, y
 
     def _compute_loss(self, y_true, y_predicted, criterion):
         loss = 0
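
The body of _compute_loss is not shown here beyond loss = 0 and its criterion parameter; the sketch below is a hypothetical reading that accumulates the criterion over horizon steps, included only to illustrate the signature in this hunk, not the repository's actual implementation.

import torch

def compute_loss(y_true, y_predicted, criterion):
    # shapes assumed: (horizon, batch_size, num_nodes * output_dim)
    loss = 0
    for t in range(y_true.size(0)):                      # accumulate over horizon steps
        loss = loss + criterion(y_predicted[t], y_true[t])
    return loss / y_true.size(0)

criterion = torch.nn.L1Loss()
y_true = torch.randn(12, 2, 414)
y_predicted = torch.randn(12, 2, 414)
print(compute_loss(y_true, y_predicted, criterion).item())
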