removed logging of every horizon
parent 765142de00
commit 02fb2430f0
@@ -109,20 +109,10 @@ class DCRNNSupervisor:
         val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
         losses = []

-        per_timestep_loss = torch.zeros(12)  # hardcoded batch size, horizon
-        num_batches = 0
-
-        for batch_i, (x, y) in enumerate(val_iterator):
+        for _, (x, y) in enumerate(val_iterator):
             x, y = self._prepare_data(x, y)

             output = self.dcrnn_model(x)

-            # (horizon, batch_size, num_sensor * output_dim)
-            for t in range(y.size(0)):
-                per_timestep_loss[t] += self._compute_loss(y[t], output[t])
-
-            num_batches += 1
-
             loss = self._compute_loss(y, output)
             losses.append(loss.item())
@@ -130,11 +120,6 @@ class DCRNNSupervisor:

         self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen)

-        per_timestep_loss /= num_batches
-
-        for i, val in enumerate(per_timestep_loss):
-            self._logger.info("Dataset:{}, Timestep: {}, MAE:{:.4f}".format(dataset, i, val.item()))
-
         return mean_loss

     def _train(self, base_lr,
@@ -148,8 +133,6 @@ class DCRNNSupervisor:
         lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=steps,
                                                             gamma=lr_decay_ratio)

-        self.dcrnn_model = self.dcrnn_model.train()
-
         self._logger.info('Start training ...')

         # this will fail if model is loaded with a changed batch_size
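Note: the removed code accumulated a per-horizon MAE with a hard-coded torch.zeros(12) buffer and a manual num_batches counter. If per-horizon metrics are wanted again, a minimal sketch of a standalone helper could compute them once from the collected predictions instead of inside the batch loop. The helper name is hypothetical and it uses a plain, unmasked MAE, ignoring any masking that self._compute_loss may apply; only the tensor shape is taken from the comment in the removed code.

import torch

def per_horizon_mae(y_true_batches, y_pred_batches):
    # Each list element has shape (horizon, batch_size, num_sensor * output_dim),
    # matching the shape comment in the removed code.
    y_true = torch.cat(y_true_batches, dim=1)  # concatenate along the batch axis
    y_pred = torch.cat(y_pred_batches, dim=1)
    # Mean absolute error per horizon step, averaged over batch and sensors.
    return (y_pred - y_true).abs().mean(dim=(1, 2))

Inside the evaluation loop this only requires appending y and output to two lists; the per-horizon values can then be logged once per evaluation call rather than once per batch.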
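The last hunk also drops self.dcrnn_model = self.dcrnn_model.train() before the training loop. Since torch.nn.Module.train() toggles the mode in place and returns the module itself, the assignment form was redundant anyway; whether training mode is restored elsewhere (for example once per epoch) is not visible in this diff. A minimal sketch of that behaviour, using a throwaway module for illustration:

import torch.nn as nn

model = nn.Linear(4, 2)
# train() switches the module to training mode and returns the module itself,
# so `model = model.train()` and `model.train()` are equivalent.
assert model.train() is model
assert model.training is True
# eval() is the inverse toggle used by the evaluation path.
assert model.eval() is model
assert model.training is False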