From 2560e1d954d37d3a221273ba0f452dedc8543eae Mon Sep 17 00:00:00 2001
From: Chintan Shah
Date: Mon, 7 Oct 2019 20:03:00 -0400
Subject: [PATCH] Added per timestep loss

---
 model/pytorch/dcrnn_supervisor.py | 20 +++++++++++++++++++-
 1 file changed, 19 insertions(+), 1 deletion(-)

diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index cb66455..af5732d 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -109,10 +109,20 @@ class DCRNNSupervisor:
             val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
             losses = []
 
-            for _, (x, y) in enumerate(val_iterator):
+            per_timestep_loss = torch.zeros(12)  # hardcoded horizon length -- TODO: read from model config
+            num_batches = 0
+
+            for batch_i, (x, y) in enumerate(val_iterator):
                 x, y = self._prepare_data(x, y)
                 output = self.dcrnn_model(x)
+
+                # y / output: (horizon, batch_size, num_sensor * output_dim)
+                # .item() detaches the scalar so no autograd graph is retained across batches
+                for t in range(y.size(0)):
+                    per_timestep_loss[t] += self._compute_loss(y[t], output[t]).item()
+
+                num_batches += 1
+
                 loss = self._compute_loss(y, output)
                 losses.append(loss.item())
 
             mean_loss = np.mean(losses)
@@ -120,6 +130,11 @@ class DCRNNSupervisor:
 
             self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen)
 
+            per_timestep_loss /= num_batches
+
+            for i, val in enumerate(per_timestep_loss):
+                self._logger.info("Dataset:{}, Timestep: {}, MAE:{}".format(dataset, i, val.item()))
+
             return mean_loss
 
     def _train(self, base_lr,
@@ -144,6 +159,9 @@ class DCRNNSupervisor:
         batches_seen = num_batches * self._epoch_num
 
         for epoch_num in range(self._epoch_num, epochs):
+
+            self.dcrnn_model = self.dcrnn_model.train()
+
             train_iterator = self._data['train_loader'].get_iterator()
             losses = []