diff --git a/model/pytorch/dcrnn_cell.py b/model/pytorch/dcrnn_cell.py
index 0779a72..2a0b85f 100644
--- a/model/pytorch/dcrnn_cell.py
+++ b/model/pytorch/dcrnn_cell.py
@@ -3,6 +3,8 @@ import torch
 
 from lib import utils
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class LayerParams:
     def __init__(self, rnn_network: torch.nn.Module, layer_type: str):
@@ -13,7 +15,8 @@ class LayerParams:
 
     def get_weights(self, shape):
         if shape not in self._params_dict:
-            nn_param = torch.nn.Parameter(torch.nn.init.xavier_normal(torch.empty(*shape)))
+            nn_param = torch.nn.Parameter(torch.empty(*shape, device=device))
+            torch.nn.init.xavier_normal_(nn_param)
             self._params_dict[shape] = nn_param
             self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
                                                  nn_param)
@@ -21,7 +24,8 @@ class LayerParams:
 
     def get_biases(self, length, bias_start=0.0):
         if length not in self._biases_dict:
-            biases = torch.nn.Parameter(torch.nn.init.constant(torch.empty(length), bias_start))
+            biases = torch.nn.Parameter(torch.empty(length, device=device))
+            torch.nn.init.constant_(biases, bias_start)
             self._biases_dict[length] = biases
             self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
                                                  biases)
diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index 67c8440..6f3cdd2 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -7,6 +7,8 @@ import torch
 from lib import utils
 from model.pytorch.dcrnn_model import DCRNNModel
 
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
 
 class DCRNNSupervisor:
     def __init__(self, adj_mx, **kwargs):
@@ -36,7 +38,6 @@ class DCRNNSupervisor:
 
         # setup model
         dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
-        print(dcrnn_model)
         self.dcrnn_model = dcrnn_model.cuda() if torch.cuda.is_available() else dcrnn_model
         self._logger.info("Model created")
 
@@ -215,7 +216,7 @@ class DCRNNSupervisor:
         x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
         y = y[..., :self.output_dim].view(self.horizon, batch_size,
                                           self.num_nodes * self.output_dim)
-        return x, y
+        return x.to(device), y.to(device)
 
     def _compute_loss(self, y_true, y_predicted, criterion):
         loss = 0
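The same device-placement pattern appears in both files: resolve the device once at import time, allocate parameters directly on it, initialize them in place with the underscore-suffixed `torch.nn.init` functions (the plain `xavier_normal`/`constant` forms are deprecated), and move each batch onto the device before the forward pass. Below is a minimal, self-contained sketch of that pattern; the `TinyModel` class and the tensor shapes are illustrative and not part of the repository.

```python
import torch

# Resolve the device once, exactly as the diff does in both modules.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TinyModel(torch.nn.Module):
    """Illustrative module: parameters are created on `device` and
    initialized in place with the underscore init functions."""

    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        weight = torch.nn.Parameter(torch.empty(in_dim, out_dim, device=device))
        torch.nn.init.xavier_normal_(weight)   # in-place; replaces deprecated xavier_normal
        bias = torch.nn.Parameter(torch.empty(out_dim, device=device))
        torch.nn.init.constant_(bias, 0.0)     # in-place; replaces deprecated constant
        self.register_parameter('weight', weight)
        self.register_parameter('bias', bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x @ self.weight + self.bias


model = TinyModel(8, 4)
x = torch.randn(32, 8)       # batch produced on the CPU, e.g. by a data loader
y = model(x.to(device))      # move the batch to the model's device, as _prepare_data now does
print(y.shape)               # torch.Size([32, 4])
```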