diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index 0497323..c2a34d4 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -131,8 +131,8 @@ class DCRNNSupervisor: loss = self._compute_loss(y, output) losses.append(loss.item()) - y_truths.append(y) - y_preds.append(output) + y_truths.append(y.cpu()) + y_preds.append(output.cpu()) mean_loss = np.mean(losses)