cuda no grad
This commit is contained in:
parent
593e3db1bf
commit
5a790d5586
|
|
@ -37,6 +37,7 @@ class DCRNNSupervisor:
|
||||||
# setup model
|
# setup model
|
||||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||||
|
self._logger.info("Model created")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _get_log_dir(kwargs):
|
def _get_log_dir(kwargs):
|
||||||
|
|
@ -93,21 +94,22 @@ class DCRNNSupervisor:
|
||||||
Computes mean L1Loss
|
Computes mean L1Loss
|
||||||
:return: mean L1Loss
|
:return: mean L1Loss
|
||||||
"""
|
"""
|
||||||
self.dcrnn_model = self.dcrnn_model.eval()
|
with torch.no_grad():
|
||||||
|
self.dcrnn_model = self.dcrnn_model.eval()
|
||||||
|
|
||||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||||
losses = []
|
losses = []
|
||||||
criterion = torch.nn.L1Loss()
|
criterion = torch.nn.L1Loss()
|
||||||
|
|
||||||
for _, (x, y) in enumerate(val_iterator):
|
for _, (x, y) in enumerate(val_iterator):
|
||||||
x, y = self._get_x_y(x, y)
|
x, y = self._get_x_y(x, y)
|
||||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||||
|
|
||||||
output = self.dcrnn_model(x)
|
output = self.dcrnn_model(x)
|
||||||
loss = self._compute_loss(y, output, criterion)
|
loss = self._compute_loss(y, output, criterion)
|
||||||
losses.append(loss.item())
|
losses.append(loss.item())
|
||||||
|
|
||||||
return np.mean(losses)
|
return np.mean(losses)
|
||||||
|
|
||||||
def _train(self, base_lr,
|
def _train(self, base_lr,
|
||||||
steps, patience=50, epochs=100,
|
steps, patience=50, epochs=100,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue