From 20c6aa5862163984313cda19a4a4913cc72fb981 Mon Sep 17 00:00:00 2001
From: Chintan Shah
Date: Fri, 4 Oct 2019 16:05:52 -0400
Subject: [PATCH] Fixed bugs with refactoring

---
 model/pytorch/dcrnn_model.py      | 9 +++++++--
 model/pytorch/dcrnn_supervisor.py | 7 +------
 2 files changed, 8 insertions(+), 8 deletions(-)

diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py
index 5c04ac8..9aa88c7 100644
--- a/model/pytorch/dcrnn_model.py
+++ b/model/pytorch/dcrnn_model.py
@@ -67,7 +67,6 @@ class DecoderModel(nn.Module, Seq2SeqAttrs):
         nn.Module.__init__(self)
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.output_dim = int(model_kwargs.get('output_dim', 1))
-        self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
         self.horizon = int(model_kwargs.get('horizon', 1))  # for the decoder
         self.projection_layer = nn.Linear(self.hidden_state_size, self.num_nodes * self.output_dim)
         self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.output_dim,
@@ -105,8 +104,14 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.encoder_model = EncoderModel(adj_mx, **model_kwargs)
         self.decoder_model = DecoderModel(adj_mx, **model_kwargs)
+        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
+        self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
         self._logger = logger
 
+    def _compute_sampling_threshold(self, batches_seen):
+        return self.cl_decay_steps / (
+                self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
+
     def encoder(self, inputs):
         """
         encoder forward pass on t time steps
@@ -128,7 +133,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs):
         :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
         """
         batch_size = encoder_hidden_state.size(1)
-        go_symbol = torch.zeros((batch_size, self.num_nodes * self.output_dim))
+        go_symbol = torch.zeros((batch_size, self.num_nodes * self.decoder_model.output_dim))
         decoder_hidden_state = encoder_hidden_state
         decoder_input = go_symbol
 
diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index 6242dfd..c919c72 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -32,7 +32,6 @@ class DCRNNSupervisor:
         self.input_dim = int(self._model_kwargs.get('input_dim', 1))
         self.seq_len = int(self._model_kwargs.get('seq_len'))  # for the encoder
         self.output_dim = int(self._model_kwargs.get('output_dim', 1))
-        self.cl_decay_steps = int(self._model_kwargs.get('cl_decay_steps', 1000))
         self.use_curriculum_learning = bool(
             self._model_kwargs.get('use_curriculum_learning', False))
         self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder
@@ -94,7 +93,7 @@ class DCRNNSupervisor:
                 x, y = self._get_x_y(x, y)
                 x, y = self._get_x_y_in_correct_dims(x, y)
 
-                output = self.dcrnn_model(x, y)
+                output = self.dcrnn_model(x, y, batches_seen)
                 loss = self._compute_loss(y, output, criterion)
                 self._logger.info(loss.item())
                 losses.append(loss.item())
@@ -143,10 +142,6 @@ class DCRNNSupervisor:
                                           self.num_nodes * self.output_dim)
         return x, y
 
-    def _compute_sampling_threshold(self, batches_seen):
-        return self.cl_decay_steps / (
-                self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
-
     def _compute_loss(self, y_true, y_predicted, criterion):
         loss = 0
         for t in range(self.horizon):
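
Note (not part of the patch): a minimal, self-contained sketch of how the inverse-sigmoid threshold this patch moves into DCRNNModel is typically used for curriculum learning (scheduled sampling) during decoding. The helper names and toy tensors below are illustrative assumptions, not the repository's code.

# Sketch only: scheduled sampling driven by the cl_decay_steps threshold.
import numpy as np
import torch

cl_decay_steps = 1000


def compute_sampling_threshold(batches_seen):
    # Inverse-sigmoid decay: close to 1 early in training (mostly teacher
    # forcing), decaying toward 0 as batches_seen grows.
    return cl_decay_steps / (cl_decay_steps + np.exp(batches_seen / cl_decay_steps))


def next_decoder_input(prediction, target, batches_seen, training=True):
    # With probability equal to the threshold, feed the ground-truth frame
    # back into the decoder; otherwise feed the model's own prediction.
    if training and np.random.uniform(0, 1) < compute_sampling_threshold(batches_seen):
        return target
    return prediction


# Toy usage: early in training the target is chosen most of the time.
prediction = torch.zeros(2, 4)  # model output for one step (batch=2, features=4)
target = torch.ones(2, 4)       # ground truth for the same step
print(next_decoder_input(prediction, target, batches_seen=10))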