moving tensors to GPU
This commit is contained in:
parent
9454fd91a2
commit
ba304e9f04
|
|
@ -3,6 +3,8 @@ import torch
|
|||
|
||||
from lib import utils
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class LayerParams:
|
||||
def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
|
||||
|
|
@ -13,7 +15,8 @@ class LayerParams:
|
|||
|
||||
def get_weights(self, shape):
|
||||
if shape not in self._params_dict:
|
||||
nn_param = torch.nn.Parameter(torch.nn.init.xavier_normal(torch.empty(*shape)))
|
||||
nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
|
||||
torch.nn.init.xavier_normal_(nn_param)
|
||||
self._params_dict[shape] = nn_param
|
||||
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
|
||||
nn_param)
|
||||
|
|
@ -21,7 +24,8 @@ class LayerParams:
|
|||
|
||||
def get_biases(self, length, bias_start=0.0):
|
||||
if length not in self._biases_dict:
|
||||
biases = torch.nn.Parameter(torch.nn.init.constant(torch.empty(length), bias_start))
|
||||
biases = torch.nn.Parameter(torch.empty(length, device=device))
|
||||
torch.nn.init.constant_(biases, bias_start)
|
||||
self._biases_dict[length] = biases
|
||||
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
|
||||
biases)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,8 @@ import torch
|
|||
from lib import utils
|
||||
from model.pytorch.dcrnn_model import DCRNNModel
|
||||
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
|
||||
class DCRNNSupervisor:
|
||||
def __init__(self, adj_mx, **kwargs):
|
||||
|
|
@ -36,7 +38,6 @@ class DCRNNSupervisor:
|
|||
|
||||
# setup model
|
||||
dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
|
||||
print(dcrnn_model)
|
||||
self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
|
||||
self._logger.info("Model created")
|
||||
|
||||
|
|
@ -215,7 +216,7 @@ class DCRNNSupervisor:
|
|||
x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
|
||||
y = y[..., :self.output_dim].view(self.horizon, batch_size,
|
||||
self.num_nodes * self.output_dim)
|
||||
return x, y
|
||||
return x.to(device), y.to(device)
|
||||
|
||||
def _compute_loss(self, y_true, y_predicted, criterion):
|
||||
loss = 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue