cuda no grad

This commit is contained in:
Chintan Shah 2019-10-04 23:30:10 -04:00
parent 593e3db1bf
commit 5a790d5586
1 changed files with 13 additions and 11 deletions

View File

@ -37,6 +37,7 @@ class DCRNNSupervisor:
# setup model
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
self._logger.info("Model created")
@staticmethod
def _get_log_dir(kwargs):
@ -93,21 +94,22 @@ class DCRNNSupervisor:
Computes mean L1Loss
:return: mean L1Loss
"""
self.dcrnn_model = self.dcrnn_model.eval()
with torch.no_grad():
self.dcrnn_model = self.dcrnn_model.eval()
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
losses = []
criterion = torch.nn.L1Loss()
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
losses = []
criterion = torch.nn.L1Loss()
for _, (x, y) in enumerate(val_iterator):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
for _, (x, y) in enumerate(val_iterator):
x, y = self._get_x_y(x, y)
x, y = self._get_x_y_in_correct_dims(x, y)
output = self.dcrnn_model(x)
loss = self._compute_loss(y, output, criterion)
losses.append(loss.item())
output = self.dcrnn_model(x)
loss = self._compute_loss(y, output, criterion)
losses.append(loss.item())
return np.mean(losses)
return np.mean(losses)
def _train(self, base_lr,
steps, patience=50, epochs=100,