cuda no grad
This commit is contained in:
parent
593e3db1bf
commit
5a790d5586
|
|
@ -37,6 +37,7 @@ class DCRNNSupervisor:
|
||||||
# setup model
|
# setup model
|
||||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||||
|
self._logger.info("Model created")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_log_dir(kwargs):
|
def _get_log_dir(kwargs):
|
||||||
|
|
@ -93,6 +94,7 @@ class DCRNNSupervisor:
|
||||||
Computes mean L1Loss
|
Computes mean L1Loss
|
||||||
:return: mean L1Loss
|
:return: mean L1Loss
|
||||||
"""
|
"""
|
||||||
|
with torch.no_grad():
|
||||||
self.dcrnn_model = self.dcrnn_model.eval()
|
self.dcrnn_model = self.dcrnn_model.eval()
|
||||||
|
|
||||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue