returning predictions from the model during eval at every timestep
This commit is contained in:
parent
46b552e075
commit
dda7013f07
|
|
@ -89,11 +89,23 @@ class DCRNNSupervisor:
|
||||||
return 'models/epo%d.tar' % epoch
|
return 'models/epo%d.tar' % epoch
|
||||||
|
|
||||||
def load_model(self):
|
def load_model(self):
|
||||||
|
self._setup_graph()
|
||||||
assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num
|
assert os.path.exists('models/epo%d.tar' % self._epoch_num), 'Weights at epoch %d not found' % self._epoch_num
|
||||||
checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu')
|
checkpoint = torch.load('models/epo%d.tar' % self._epoch_num, map_location='cpu')
|
||||||
self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
|
self.dcrnn_model.load_state_dict(checkpoint['model_state_dict'])
|
||||||
self._logger.info("Loaded model at {}".format(self._epoch_num))
|
self._logger.info("Loaded model at {}".format(self._epoch_num))
|
||||||
|
|
||||||
|
def _setup_graph(self):
|
||||||
|
with torch.no_grad():
|
||||||
|
self.dcrnn_model = self.dcrnn_model.eval()
|
||||||
|
|
||||||
|
val_iterator = self._data['val_loader'].get_iterator()
|
||||||
|
|
||||||
|
for _, (x, y) in enumerate(val_iterator):
|
||||||
|
x, y = self._prepare_data(x, y)
|
||||||
|
output = self.dcrnn_model(x)
|
||||||
|
break
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
kwargs.update(self._train_kwargs)
|
kwargs.update(self._train_kwargs)
|
||||||
return self._train(**kwargs)
|
return self._train(**kwargs)
|
||||||
|
|
@ -109,6 +121,9 @@ class DCRNNSupervisor:
|
||||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||||
losses = []
|
losses = []
|
||||||
|
|
||||||
|
y_truths = []
|
||||||
|
y_preds = []
|
||||||
|
|
||||||
for _, (x, y) in enumerate(val_iterator):
|
for _, (x, y) in enumerate(val_iterator):
|
||||||
x, y = self._prepare_data(x, y)
|
x, y = self._prepare_data(x, y)
|
||||||
|
|
||||||
|
|
@ -116,11 +131,25 @@ class DCRNNSupervisor:
|
||||||
loss = self._compute_loss(y, output)
|
loss = self._compute_loss(y, output)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
|
y_truths.append(y)
|
||||||
|
y_preds.append(output)
|
||||||
|
|
||||||
mean_loss = np.mean(losses)
|
mean_loss = np.mean(losses)
|
||||||
|
|
||||||
self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen)
|
self._writer.add_scalar('{} loss'.format(dataset), mean_loss, batches_seen)
|
||||||
|
|
||||||
return mean_loss
|
y_preds = np.concatenate(y_preds, axis=1)
|
||||||
|
y_truths = np.concatenate(y_truths, axis=1) # concatenate on batch dimension
|
||||||
|
|
||||||
|
y_truths_scaled = []
|
||||||
|
y_preds_scaled = []
|
||||||
|
for t in range(y_preds.shape[0]):
|
||||||
|
y_truth = self.standard_scaler.inverse_transform(y_truths[t])
|
||||||
|
y_pred = self.standard_scaler.inverse_transform(y_preds[t])
|
||||||
|
y_truths_scaled.append(y_truth)
|
||||||
|
y_preds_scaled.append(y_pred)
|
||||||
|
|
||||||
|
return mean_loss, {'prediction': y_preds_scaled, 'truth': y_truths_scaled}
|
||||||
|
|
||||||
def _train(self, base_lr,
|
def _train(self, base_lr,
|
||||||
steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
|
steps, patience=50, epochs=100, lr_decay_ratio=0.1, log_every=1, save_model=1,
|
||||||
|
|
@ -178,7 +207,7 @@ class DCRNNSupervisor:
|
||||||
lr_scheduler.step()
|
lr_scheduler.step()
|
||||||
self._logger.info("evaluating now!")
|
self._logger.info("evaluating now!")
|
||||||
|
|
||||||
val_loss = self.evaluate(dataset='val', batches_seen=batches_seen)
|
val_loss, _ = self.evaluate(dataset='val', batches_seen=batches_seen)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
|
@ -194,7 +223,7 @@ class DCRNNSupervisor:
|
||||||
self._logger.info(message)
|
self._logger.info(message)
|
||||||
|
|
||||||
if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
|
if (epoch_num % test_every_n_epochs) == test_every_n_epochs - 1:
|
||||||
test_loss = self.evaluate(dataset='test', batches_seen=batches_seen)
|
test_loss, _ = self.evaluate(dataset='test', batches_seen=batches_seen)
|
||||||
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
|
message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f}, lr: {:.6f}, ' \
|
||||||
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
'{:.1f}s'.format(epoch_num, epochs, batches_seen,
|
||||||
np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
|
np.mean(losses), test_loss, lr_scheduler.get_lr()[0],
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue