cuda no grad

This commit is contained in:
Chintan Shah 2019-10-04 23:30:10 -04:00
parent 593e3db1bf
commit 5a790d5586
1 changed files with 13 additions and 11 deletions

View File

@ -37,6 +37,7 @@ class DCRNNSupervisor:
# setup model # setup model
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs) dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
self._logger.info("Model created")
@staticmethod @staticmethod
def _get_log_dir(kwargs): def _get_log_dir(kwargs):
@ -93,6 +94,7 @@ class DCRNNSupervisor:
Computes mean L1Loss Computes mean L1Loss
:return: mean L1Loss :return: mean L1Loss
""" """
with torch.no_grad():
self.dcrnn_model = self.dcrnn_model.eval() self.dcrnn_model = self.dcrnn_model.eval()
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator() val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()