moving tensors to GPU

This commit is contained in:
Chintan Shah 2019-10-06 14:00:54 -04:00
parent 9454fd91a2
commit ba304e9f04
2 changed files with 9 additions and 4 deletions

View File

@ -3,6 +3,8 @@ import torch
from lib import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LayerParams:
def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
@ -13,7 +15,8 @@ class LayerParams:
def get_weights(self, shape):
if shape not in self._params_dict:
nn_param = torch.nn.Parameter(torch.nn.init.xavier_normal(torch.empty(*shape)))
nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
torch.nn.init.xavier_normal_(nn_param)
self._params_dict[shape] = nn_param
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
nn_param)
@ -21,7 +24,8 @@ class LayerParams:
def get_biases(self, length, bias_start=0.0):
if length not in self._biases_dict:
biases = torch.nn.Parameter(torch.nn.init.constant(torch.empty(length), bias_start))
biases = torch.nn.Parameter(torch.empty(length, device=device))
torch.nn.init.constant_(biases, bias_start)
self._biases_dict[length] = biases
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
biases)

View File

@ -7,6 +7,8 @@ import torch
from lib import utils
from model.pytorch.dcrnn_model import DCRNNModel
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DCRNNSupervisor:
def __init__(self, adj_mx, **kwargs):
@ -36,7 +38,6 @@ class DCRNNSupervisor:
# setup model
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
print(dcrnn_model)
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
self._logger.info("Model created")
@ -215,7 +216,7 @@ class DCRNNSupervisor:
x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
y = y[..., :self.output_dim].view(self.horizon, batch_size,
self.num_nodes * self.output_dim)
return x, y
return x.to(device), y.to(device)
def _compute_loss(self, y_true, y_predicted, criterion):
loss = 0