cuda no grad
This commit is contained in:
parent
593e3db1bf
commit
5a790d5586
|
|
@ -37,6 +37,7 @@ class DCRNNSupervisor:
|
|||
# setup model
|
||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||
self._logger.info("Model created")
|
||||
|
||||
@staticmethod
|
||||
def _get_log_dir(kwargs):
|
||||
|
|
@ -93,21 +94,22 @@ class DCRNNSupervisor:
|
|||
Computes mean L1Loss
|
||||
:return: mean L1Loss
|
||||
"""
|
||||
self.dcrnn_model = self.dcrnn_model.eval()
|
||||
with torch.no_grad():
|
||||
self.dcrnn_model = self.dcrnn_model.eval()
|
||||
|
||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||
losses = []
|
||||
criterion = torch.nn.L1Loss()
|
||||
val_iterator = self._data['{}_loader'.format(dataset)].get_iterator()
|
||||
losses = []
|
||||
criterion = torch.nn.L1Loss()
|
||||
|
||||
for _, (x, y) in enumerate(val_iterator):
|
||||
x, y = self._get_x_y(x, y)
|
||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||
for _, (x, y) in enumerate(val_iterator):
|
||||
x, y = self._get_x_y(x, y)
|
||||
x, y = self._get_x_y_in_correct_dims(x, y)
|
||||
|
||||
output = self.dcrnn_model(x)
|
||||
loss = self._compute_loss(y, output, criterion)
|
||||
losses.append(loss.item())
|
||||
output = self.dcrnn_model(x)
|
||||
loss = self._compute_loss(y, output, criterion)
|
||||
losses.append(loss.item())
|
||||
|
||||
return np.mean(losses)
|
||||
return np.mean(losses)
|
||||
|
||||
def _train(self, base_lr,
|
||||
steps, patience=50, epochs=100,
|
||||
|
|
|
|||
Loading…
Reference in New Issue