Using model.cuda() if cuda is available

This commit is contained in:
Chintan Shah 2019-10-04 22:45:08 -04:00
parent 8d3b1d0d66
commit 593e3db1bf
2 changed files with 3 additions and 9 deletions

View File

@ -1,11 +1,7 @@
from typing import Any
import numpy as np
import torch
import torch.nn as nn
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class Seq2SeqAttrs:
def __init__(self, adj_mx, **model_kwargs):
@ -49,8 +45,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
"""
batch_size, _ = inputs.size()
if hidden_state is None:
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
device=device)
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
hidden_states = []
output = inputs
for layer_num, dcgru_layer in enumerate(self.dcgru_layers):

View File

@ -7,8 +7,6 @@ import torch
from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DCRNNSupervisor:
def __init__(self, adj_mx, **kwargs):
@ -37,7 +35,8 @@ class DCRNNSupervisor:
self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder
# setup model
self.dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
@staticmethod
def _get_log_dir(kwargs):