Using model.cuda() if cuda is available
This commit is contained in:
parent
8d3b1d0d66
commit
593e3db1bf
|
|
@ -1,11 +1,7 @@
|
|||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class Seq2SeqAttrs:
|
||||
def __init__(self, adj_mx, **model_kwargs):
|
||||
|
|
@ -49,8 +45,7 @@ class EncoderModel(nn.Module, Seq2SeqAttrs):
|
|||
"""
|
||||
batch_size, _ = inputs.size()
|
||||
if hidden_state is None:
|
||||
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size),
|
||||
device=device)
|
||||
hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size))
|
||||
hidden_states = []
|
||||
output = inputs
|
||||
for layer_num, dcgru_layer in enumerate(self.dcgru_layers):
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ import torch
|
|||
from lib import utils
|
||||
from model.pytorch.dcrnn_model import DCRNNModel
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class DCRNNSupervisor:
|
||||
def __init__(self, adj_mx, **kwargs):
|
||||
|
|
@ -37,7 +35,8 @@ class DCRNNSupervisor:
|
|||
self.horizon = int(self._model_kwargs.get('horizon', 1)) # for the decoder
|
||||
|
||||
# setup model
|
||||
self.dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||
|
||||
@staticmethod
|
||||
def _get_log_dir(kwargs):
|
||||
|
|
|
|||
Loading…
Reference in New Issue