diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index d496dd5..478cbdf 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -17,6 +17,8 @@ class DCRNNSupervisor:
         self._model_kwargs = kwargs.get('model')
         self._train_kwargs = kwargs.get('train')
 
+        self.max_grad_norm = self._train_kwargs.get('max_grad_norm', 1.)
+
         # logging.
         self._log_dir = self._get_log_dir(kwargs)
         log_level = self._kwargs.get('log_level', 'INFO')
@@ -104,11 +106,15 @@ class DCRNNSupervisor:
         decoder_hidden_state = encoder_hidden_state
         decoder_input = go_symbol
 
+        outputs = []
+
         for t in range(self.horizon):
             decoder_output, decoder_hidden_state = self.decoder_model.forward(decoder_input,
                                                                               decoder_hidden_state)
             decoder_input = decoder_output
+            outputs.append(decoder_output)
+
             if self.use_curriculum_learning:
                 # todo check for is_training (pytorch way?)
                 c = np.random.uniform(0, 1)
                 if c < self._compute_sampling_threshold(batches_seen):
@@ -119,9 +125,16 @@ class DCRNNSupervisor:
         self._logger.info("Decoder complete, starting backprop")
         loss.backward()
+
+        # gradient clipping - this does it in place
+        torch.nn.utils.clip_grad_norm_(self.encoder_model.parameters(), self.max_grad_norm)
+        torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), self.max_grad_norm)
+
         encoder_optimizer.step()
         decoder_optimizer.step()
-        return loss.item()
+
+        outputs = torch.stack(outputs)
+        return outputs.view(self.horizon, batch_size, self.num_nodes, self.output_dim), loss.item()
 
     def _train(self, base_lr,
                steps, patience=50, epochs=100,
@@ -147,8 +160,8 @@ class DCRNNSupervisor:
                 self._logger.debug("y: {}".format(y.size()))
                 x = x.permute(1, 0, 2, 3)
                 y = y.permute(1, 0, 2, 3)
-                loss = self._train_one_batch(x, y, batches_seen, encoder_optimizer,
-                                             decoder_optimizer, criterion)
+                output, loss = self._train_one_batch(x, y, batches_seen, encoder_optimizer,
+                                                     decoder_optimizer, criterion)
                 losses.append(loss)
                 batches_seen += 1
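
For context, a minimal self-contained sketch (not part of the diff) of the clip-before-step pattern the change introduces. The names model, optimizer, and the toy tensors are hypothetical stand-ins for the encoder/decoder models and their optimizers in DCRNNSupervisor; only the ordering backward() -> clip_grad_norm_() -> step() reflects the actual change.

# Hypothetical stand-ins; not the DCRNN encoder/decoder themselves.
import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
max_grad_norm = 1.0  # mirrors self._train_kwargs.get('max_grad_norm', 1.)

x = torch.randn(8, 4)
y = torch.randn(8, 1)

optimizer.zero_grad()
loss = torch.nn.functional.mse_loss(model(x), y)
loss.backward()
# Rescales the gradients in place so their total norm is at most max_grad_norm;
# it must run after backward() and before optimizer.step(), as in the diff above.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
optimizer.step()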