From 017ec70783f8247e1f70a5ac203970530290423a Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Sun, 6 Oct 2019 14:10:20 -0400 Subject: [PATCH] moving tensors to GPU [v2] --- model/pytorch/dcrnn_model.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 4a14334..9390c6e 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -4,6 +4,8 @@ import torch.nn as nn from model.pytorch.dcrnn_cell import DCGRUCell +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + class Seq2SeqAttrs: def __init__(self, adj_mx, **model_kwargs): @@ -40,7 +42,8 @@ class EncoderModel(nn.Module, Seq2SeqAttrs): """ batch_size, _ = inputs.size() if hidden_state is None: - hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size)) + hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.hidden_state_size), + device=device) hidden_states = [] output = inputs for layer_num, dcgru_layer in enumerate(self.dcgru_layers): @@ -122,7 +125,8 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs): :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) """ batch_size = encoder_hidden_state.size(1) - go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim)) + go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim), + device=device) decoder_hidden_state = encoder_hidden_state decoder_input = go_symbol