Implemented eval and function
This commit is contained in:
parent 20c6aa5862
commit d9f41172dc
@@ -71,6 +71,27 @@ class DCRNNSupervisor:
         kwargs.update(self._train_kwargs)
         return self._train(**kwargs)
 
+    def evaluate(self, dataset='val'):
+        """
+        Computes mean L1Loss
+        :return: mean L1Loss
+        """
+        self.dcrnn_model = self.dcrnn_model.eval()
+
+        val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
+        losses = []
+        criterion = torch.nn.L1Loss()
+
+        for _, (x, y) in enumerate(val_iterator):
+            x, y = self._get_x_y(x, y)
+            x, y = self._get_x_y_in_correct_dims(x, y)
+
+            output = self.dcrnn_model(x)
+            loss = self._compute_loss(y, output, criterion)
+            losses.append(loss.item())
+
+        return np.mean(losses)
+
     def _train(self, base_lr,
                steps, patience=50, epochs=100,
                min_learning_rate=2e-6, lr_decay_ratio=0.1, log_every=10, save_model=1,
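
Note: the new evaluate() follows the usual PyTorch evaluation pattern: put the model in eval mode, run every batch from the chosen loader through it, score each batch with torch.nn.L1Loss (MAE), and return the mean over batches. A minimal standalone sketch of that pattern is below; model, data_loader and device are illustrative placeholders rather than names from this repository, and the torch.no_grad() wrapper is a common refinement that the committed version does not use.

import numpy as np
import torch


def mean_l1_eval(model, data_loader, device='cpu'):
    """Sketch: average L1 (MAE) loss of `model` over all batches in `data_loader`."""
    criterion = torch.nn.L1Loss()
    model = model.eval()              # disable dropout / freeze batch-norm statistics
    losses = []
    with torch.no_grad():             # gradients are not needed for evaluation
        for x, y in data_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            losses.append(criterion(output, y).item())
    return np.mean(losses)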

@@ -79,6 +100,8 @@ class DCRNNSupervisor:
         optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
         criterion = torch.nn.L1Loss()  # mae loss
 
+        self.dcrnn_model = self.dcrnn_model.train()
+
         batches_seen = 0
         self._logger.info('Start training ...')
         for epoch_num in range(epochs):
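
Note: the single added line above puts the model into training mode before the epoch loop. nn.Module.train() and nn.Module.eval() return the module itself and only flip its training flag, which changes the behaviour of layers such as dropout and batch normalization; that is why the new evaluate() switches to eval mode while _train() sets train mode before the loop. A small sketch of the toggle (the two-layer model is just an example):

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Dropout(p=0.5))

model = model.train()    # train() returns the module, so the reassignment is harmless
print(model.training)    # True  -> dropout is active

model = model.eval()
print(model.training)    # False -> dropout acts as the identity at evaluation time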

@@ -106,12 +129,20 @@ class DCRNNSupervisor:
 
             optimizer.step()
 
+            val_loss = self.evaluate(dataset='val')
             end_time = time.time()
             if epoch_num % log_every == 0:
                 message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, val_mae: {:.4f} ' \
-                          'lr:{:.6f} {:.1f}s'.format(epoch_num, epochs, batches_seen,
-                                                     np.mean(losses), 0.0,
-                                                     0.0, (end_time - start_time))
+                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
+                                           np.mean(losses), val_loss,
+                                           (end_time - start_time))
+                self._logger.info(message)
+
+            if epoch_num % test_every_n_epochs == 0:
+                test_loss = self.evaluate(dataset='test')
+                message = 'Epoch [{}/{}] ({}) train_mae: {:.4f}, test_mae: {:.4f} ' \
+                          '{:.1f}s'.format(epoch_num, epochs, batches_seen,
+                                           np.mean(losses), test_loss, (end_time - start_time))
                 self._logger.info(message)
 
     def _get_x_y(self, x, y):