From 593e3db1bf3c9039eeec0d25806321a831e66342 Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Fri, 4 Oct 2019 22:45:08 -0400 Subject: [PATCH] Using model.cuda() if cuda is available --- model/pytorch/dcrnn_model.py | 7 +------ model/pytorch/dcrnn_supervisor.py | 5 ++--- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 9aa88c7..20ca2d8 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -1,11 +1,7 @@ -from typing import Any - import numpy as np import torch import torch.nn as nn -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class Seq2SeqAttrs: def __init__(self, adj_mx, **model_kwargs): @@ -49,8 +45,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs): """ batch_size, _ = inputs.size() if hidden_state is None: - hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size), - device=device) + hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size)) hidden_states = [] output = inputs for layer_num, dcgru_layer in enumerate(self.dcgru_layers): diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py index aa31eb0..81b4615 100644 --- a/model/pytorch/dcrnn_supervisor.py +++ b/model/pytorch/dcrnn_supervisor.py @@ -7,8 +7,6 @@ import torch from lib import utils from model.pytorch.dcrnn_model import DCRNNModel -device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - class DCRNNSupervisor: def __init__(self, adj_mx, **kwargs): @@ -37,7 +35,8 @@ class DCRNNSupervisor: self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder # setup model - self.dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs) + dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs) + self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model @staticmethod def _get_log_dir(kwargs):