cuda no grad
This commit is contained in:
parent
593e3db1bf
commit
5a790d5586
|
|
@ -37,6 +37,7 @@ class DCRNNSupervisor:
|
|||
# setup model
|
||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||
self._logger.info("Model created")
|
||||
|
||||
@staticmethod
|
||||
def _get_log_dir(kwargs):
|
||||
|
|
@ -93,6 +94,7 @@ class DCRNNSupervisor:
|
|||
Computes mean L1Loss
|
||||
:return: mean L1Loss
|
||||
"""
|
||||
with torch.no_grad():
|
||||
self.dcrnn_model = self.dcrnn_model.eval()
|
||||
|
||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||
|
|
|
|||
Loading…
Reference in New Issue