Added per timestep loss
This commit is contained in:
parent
3d93008a3e
commit
2560e1d954
|
|
@ -109,10 +109,20 @@ class DCRNNSupervisor:
|
|||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||
losses = []
|
||||
|
||||
for _, (x, y) in enumerate(val_iterator):
|
||||
per_timestep_loss = torch.zeros(12) # hardcoded batch size, horizon
|
||||
num_batches = 0
|
||||
|
||||
for batch_i, (x, y) in enumerate(val_iterator):
|
||||
x, y = self._prepare_data(x, y)
|
||||
|
||||
output = self.dcrnn_model(x)
|
||||
|
||||
# (horizon, batch_size, num_sensor * output_dim)
|
||||
for t in y.size(0):
|
||||
per_timestep_loss[t] += self._compute_loss(y[t], output[t])
|
||||
|
||||
num_batches += 1
|
||||
|
||||
loss = self._compute_loss(y, output)
|
||||
losses.append(loss.item())
|
||||
|
||||
|
|
@ -120,6 +130,11 @@ class DCRNNSupervisor:
|
|||
|
||||
self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen)
|
||||
|
||||
per_timestep_loss /= num_batches
|
||||
|
||||
for i, val in enumerate(per_timestep_loss):
|
||||
self._logger.info("Dataset:{}, Timestep: {}, MAE:{}".format(dataset, i, val.item()))
|
||||
|
||||
return mean_loss
|
||||
|
||||
def _train(self, base_lr,
|
||||
|
|
@ -144,6 +159,9 @@ class DCRNNSupervisor:
|
|||
batches_seen = num_batches * self._epoch_num
|
||||
|
||||
for epoch_num in range(self._epoch_num, epochs):
|
||||
|
||||
self.dcrnn_model = self.dcrnn_model.train()
|
||||
|
||||
train_iterator = self._data['train_loader'].get_iterator()
|
||||
losses = []
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue