Using model.cuda() if cuda is available

This commit is contained in:
Chintan Shah 2019-10-04 22:45:08 -04:00
parent 8d3b1d0d66
commit 593e3db1bf
2 changed files with 3 additions and 9 deletions

View File

@@ -1,11 +1,7 @@
from typing import Any
import numpy as np
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Seq2SeqAttrs:
    def __init__(self, adj_mx, **model_kwargs):
@@ -49,8 +45,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
""" """
batch_size, _ = inputs.size() batch_size, _ = inputs.size()
if hidden_state is None: if hidden_state is None:
- hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size), device=device)
+ hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
hidden_states = []
output = inputs
for layer_num, dcgru_layer in enumerate(self.dcgru_layers):

View File

@@ -7,8 +7,6 @@ import torch
from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DCRNNSupervisor:
    def __init__(self, adj_mx, **kwargs):
@@ -37,7 +35,8 @@ class DCRNNSupervisor:
self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder
# setup model
- self.dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
+ dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
@staticmethod
def _get_log_dir(kwargs):