From 2b8d5e6b31786af41e1367479eed4a4b142917cb Mon Sep 17 00:00:00 2001
From: Chintan Shah
Date: Fri, 4 Oct 2019 13:02:50 -0400
Subject: [PATCH] Refactored code and moved everything into a DCRNN forward pass

---
 model/pytorch/dcrnn_model.py      | 100 +++++++++++++++++-----
 model/pytorch/dcrnn_supervisor.py | 131 ++++++++++++------------------
 2 files changed, 135 insertions(+), 96 deletions(-)

diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py
index aaf59c5..5c04ac8 100644
--- a/model/pytorch/dcrnn_model.py
+++ b/model/pytorch/dcrnn_model.py
@@ -1,13 +1,15 @@
+from typing import Any
+
+import numpy as np
 import torch
 import torch.nn as nn
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
-class DCRNNModel:
-    def __init__(self, is_training, adj_mx, **model_kwargs):
+class Seq2SeqAttrs:
+    def __init__(self, adj_mx, **model_kwargs):
         self.adj_mx = adj_mx
-        self.is_training = is_training
         self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
         self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
         self.filter_type = model_kwargs.get('filter_type', 'laplacian')
@@ -18,12 +20,12 @@ class DCRNNModel:
         self.hidden_state_size = self.num_nodes * self.rnn_units
 
 
-class EncoderModel(nn.Module, DCRNNModel):
-    def __init__(self, is_training, adj_mx, **model_kwargs):
+class EncoderModel(nn.Module, Seq2SeqAttrs):
+    def __init__(self, adj_mx, **model_kwargs):
         # super().__init__(is_training, adj_mx, **model_kwargs)
         # https://pytorch.org/docs/stable/nn.html#gru
         nn.Module.__init__(self)
-        DCRNNModel.__init__(self, is_training, adj_mx, **model_kwargs)
+        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.input_dim = int(model_kwargs.get('input_dim', 1))
         self.seq_len = int(model_kwargs.get('seq_len'))  # for the encoder
         self.dcgru_layers = nn.ModuleList([nn.GRUCell(input_size=self.num_nodes * self.input_dim,
@@ -59,21 +61,11 @@ class EncoderModel(nn.Module, DCRNNModel):
         return output, torch.stack(hidden_states)  # runs in O(num_layers) so not too slow
 
 
-class DecoderModel(nn.Module, DCRNNModel):
-    def __init__(self, is_training, adj_mx, **model_kwargs):
+class DecoderModel(nn.Module, Seq2SeqAttrs):
+    def __init__(self, adj_mx, **model_kwargs):
         # super().__init__(is_training, adj_mx, **model_kwargs)
         nn.Module.__init__(self)
-        DCRNNModel.__init__(self, is_training, adj_mx, **model_kwargs)
-        self.adj_mx = adj_mx
-        self.is_training = is_training
-        self.max_diffusion_step = int(model_kwargs.get('max_diffusion_step', 2))
-        self.cl_decay_steps = int(model_kwargs.get('cl_decay_steps', 1000))
-        self.filter_type = model_kwargs.get('filter_type', 'laplacian')
-        # self.max_grad_norm = float(model_kwargs.get('max_grad_norm', 5.0))
-        self.num_nodes = int(model_kwargs.get('num_nodes', 1))
-        self.num_rnn_layers = int(model_kwargs.get('num_rnn_layers', 1))
-        self.rnn_units = int(model_kwargs.get('rnn_units'))
-        self.hidden_state_size = self.num_nodes * self.rnn_units
+        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
         self.output_dim = int(model_kwargs.get('output_dim', 1))
         self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
         self.horizon = int(model_kwargs.get('horizon', 1))  # for the decoder
@@ -105,3 +97,73 @@ class DecoderModel(nn.Module, DCRNNModel):
         output = next_hidden_state
 
         return self.projection_layer(output), torch.stack(hidden_states)
+
+
+class DCRNNModel(nn.Module, Seq2SeqAttrs):
+    def __init__(self, adj_mx, logger, **model_kwargs):
+        super().__init__()
+        Seq2SeqAttrs.__init__(self, adj_mx, **model_kwargs)
+        self.encoder_model = EncoderModel(adj_mx, **model_kwargs)
+        self.decoder_model = DecoderModel(adj_mx, **model_kwargs)
+        # decoder() below reads this flag on self, so define it here as well
+        self.use_curriculum_learning = bool(model_kwargs.get('use_curriculum_learning', False))
+        self._logger = logger
+
+    def _compute_sampling_threshold(self, batches_seen):
+        # inverse sigmoid decay for scheduled sampling
+        return self.cl_decay_steps / (
+                self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
+
+    def encoder(self, inputs):
+        """
+        Encoder forward pass over seq_len time steps
+        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim)
+        :return: encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
+        """
+        encoder_hidden_state = None
+        for t in range(self.encoder_model.seq_len):
+            _, encoder_hidden_state = self.encoder_model(inputs[t], encoder_hidden_state)
+
+        return encoder_hidden_state
+
+    def decoder(self, encoder_hidden_state, labels=None, batches_seen=None):
+        """
+        Decoder forward pass
+        :param encoder_hidden_state: (num_layers, batch_size, self.hidden_state_size)
+        :param labels: (self.horizon, batch_size, self.num_nodes * self.output_dim) [optional, None during inference]
+        :param batches_seen: global step [optional, None during inference]
+        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
+        """
+        batch_size = encoder_hidden_state.size(1)
+        # create the GO symbol on the same device as the model
+        go_symbol = torch.zeros((batch_size, self.num_nodes * self.output_dim), device=device)
+        decoder_hidden_state = encoder_hidden_state
+        decoder_input = go_symbol
+
+        outputs = []
+
+        for t in range(self.decoder_model.horizon):
+            decoder_output, decoder_hidden_state = self.decoder_model(decoder_input,
+                                                                      decoder_hidden_state)
+            decoder_input = decoder_output
+            outputs.append(decoder_output)
+            if self.training and self.use_curriculum_learning:
+                c = np.random.uniform(0, 1)
+                if c < self._compute_sampling_threshold(batches_seen):
+                    decoder_input = labels[t]
+        outputs = torch.stack(outputs)
+        return outputs
+
+    def forward(self, inputs, labels=None, batches_seen=None):
+        """
+        Seq2seq forward pass
+        :param inputs: shape (seq_len, batch_size, num_sensor * input_dim)
+        :param labels: shape (horizon, batch_size, num_sensor * output_dim)
+        :param batches_seen: batches seen so far
+        :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim)
+        """
+        encoder_hidden_state = self.encoder(inputs)
+        self._logger.info("Encoder complete, starting decoder")
+        outputs = self.decoder(encoder_hidden_state, labels, batches_seen=batches_seen)
+        self._logger.info("Decoder complete")
+        return outputs
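Note on the model refactor above: encoding, decoding, and scheduled sampling now live inside a single DCRNNModel.forward, so callers only need model(inputs, labels, batches_seen). Below is a minimal, self-contained sketch of that seq2seq pattern for reference. It uses plain GRU cells instead of the repo's diffusion-convolution layers, and every name in it (ToySeq2Seq, the dimensions) is illustrative only, not part of this codebase.

    import numpy as np
    import torch
    import torch.nn as nn

    class ToySeq2Seq(nn.Module):
        """Illustrative stand-in for DCRNNModel: plain GRU cells, no graph diffusion."""
        def __init__(self, input_dim, hidden_dim, seq_len, horizon, cl_decay_steps=1000):
            super().__init__()
            self.seq_len, self.horizon = seq_len, horizon
            self.cl_decay_steps = cl_decay_steps
            self.encoder_cell = nn.GRUCell(input_dim, hidden_dim)
            self.decoder_cell = nn.GRUCell(input_dim, hidden_dim)
            self.projection = nn.Linear(hidden_dim, input_dim)

        def _sampling_threshold(self, batches_seen):
            # inverse sigmoid decay, same formula as _compute_sampling_threshold
            return self.cl_decay_steps / (
                    self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))

        def forward(self, inputs, labels=None, batches_seen=None):
            # encoder: fold the whole input sequence into one hidden state
            hidden = None
            for t in range(self.seq_len):
                hidden = self.encoder_cell(inputs[t], hidden)
            # decoder: start from a GO symbol (zeros) and feed outputs back in
            decoder_input = torch.zeros_like(inputs[0])
            outputs = []
            for t in range(self.horizon):
                hidden = self.decoder_cell(decoder_input, hidden)
                decoder_input = self.projection(hidden)
                outputs.append(decoder_input)
                # scheduled sampling: while training, sometimes substitute ground truth
                if self.training and labels is not None and \
                        np.random.uniform(0, 1) < self._sampling_threshold(batches_seen):
                    decoder_input = labels[t]
            return torch.stack(outputs)

    model = ToySeq2Seq(input_dim=2, hidden_dim=16, seq_len=12, horizon=12)
    x = torch.randn(12, 8, 2)  # (seq_len, batch_size, input_dim)
    y = torch.randn(12, 8, 2)  # (horizon, batch_size, output_dim == input_dim here)
    print(model(x, y, batches_seen=0).shape)  # torch.Size([12, 8, 2])

Feeding the decoder its own previous output (rather than always the label) is what lets the same forward() serve both training and inference.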
diff --git a/model/pytorch/dcrnn_supervisor.py b/model/pytorch/dcrnn_supervisor.py
index 478cbdf..6242dfd 100644
--- a/model/pytorch/dcrnn_supervisor.py
+++ b/model/pytorch/dcrnn_supervisor.py
@@ -5,7 +5,7 @@
 import numpy as np
 import torch
 
 from lib import utils
-from model.pytorch.dcrnn_model import EncoderModel, DecoderModel
+from model.pytorch.dcrnn_model import DCRNNModel
 
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -38,8 +38,7 @@ class DCRNNSupervisor:
         self.horizon = int(self._model_kwargs.get('horizon', 1))  # for the decoder
 
         # setup model
-        self.encoder_model = EncoderModel(True, adj_mx, **self._model_kwargs)
-        self.decoder_model = DecoderModel(True, adj_mx, **self._model_kwargs)
+        self.dcrnn_model = DCRNNModel(adj_mx, self._logger, **self._model_kwargs)
 
     @staticmethod
     def _get_log_dir(kwargs):
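Folding the encoder and decoder into one DCRNNModel is also what makes the single-optimizer setup in the next hunk work: assigning them as attributes of an nn.Module registers them as submodules, so model.parameters() yields the weights of both halves. (Worth noting: the removed code below built decoder_optimizer from self.encoder_model.parameters(), so the decoder's weights were never actually being optimized.) A small sketch of the registration mechanism, using hypothetical stand-in modules:

    import torch
    import torch.nn as nn

    class Composed(nn.Module):
        def __init__(self):
            super().__init__()
            # attribute assignment registers the submodules with this module
            self.encoder_model = nn.GRUCell(4, 16)
            self.decoder_model = nn.GRUCell(4, 16)

    model = Composed()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # covers both halves
    n_enc = len(list(model.encoder_model.parameters()))
    n_dec = len(list(model.decoder_model.parameters()))
    assert len(list(model.parameters())) == n_enc + n_dec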
@@ -73,76 +72,12 @@ class DCRNNSupervisor:
         kwargs.update(self._train_kwargs)
         return self._train(**kwargs)
 
-    def _train_one_batch(self, inputs, labels, batches_seen, encoder_optimizer,
-                         decoder_optimizer, criterion):
-        """
-
-        :param inputs: shape (seq_len, batch_size, num_sensor, input_dim)
-        :param labels: shape (horizon, batch_size, num_sensor, input_dim)
-        :param encoder_optimizer:
-        :param decoder_optimizer:
-        :param criterion: minimize this criterion
-        :return: loss?
-        """
-
-        encoder_optimizer.zero_grad()
-        decoder_optimizer.zero_grad()
-
-        batch_size = inputs.size(1)
-
-        inputs = inputs.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
-        labels = labels[..., :self.output_dim].view(self.horizon, batch_size,
-                                                    self.num_nodes * self.output_dim)
-
-        loss = 0
-
-        encoder_hidden_state = None
-        for t in range(self.seq_len):
-            _, encoder_hidden_state = self.encoder_model.forward(inputs[t], encoder_hidden_state)
-
-        self._logger.info("Encoder complete, starting decoder")
-        go_symbol = torch.zeros((batch_size, self.num_nodes * self.output_dim))
-
-        decoder_hidden_state = encoder_hidden_state
-        decoder_input = go_symbol
-
-        outputs = []
-
-        for t in range(self.horizon):
-            decoder_output, decoder_hidden_state = self.decoder_model.forward(decoder_input,
-                                                                              decoder_hidden_state)
-            decoder_input = decoder_output
-
-            outputs.append(decoder_output)
-
-            if self.use_curriculum_learning:  # todo check for is_training (pytorch way?)
-                c = np.random.uniform(0, 1)
-                if c < self._compute_sampling_threshold(batches_seen):
-                    decoder_input = labels[t]
-
-            loss += criterion(self.standard_scaler.inverse_transform(decoder_output),
-                              self.standard_scaler.inverse_transform(labels[t]))
-
-        self._logger.info("Decoder complete, starting backprop")
-        loss.backward()
-
-        # gradient clipping - this does it in place
-        torch.nn.utils.clip_grad_norm_(self.encoder_model.parameters(), self.max_grad_norm)
-        torch.nn.utils.clip_grad_norm_(self.decoder_model.parameters(), self.max_grad_norm)
-
-        encoder_optimizer.step()
-        decoder_optimizer.step()
-
-        outputs = torch.stack(outputs)
-        return outputs.view(self.horizon, batch_size, self.num_nodes, self.output_dim), loss.item()
-
     def _train(self, base_lr,
                steps, patience=50, epochs=100, min_learning_rate=2e-6, lr_decay_ratio=0.1,
                log_every=10, save_model=1,
                test_every_n_epochs=10, **kwargs):
         # steps is used in learning rate - will see if need to use it?
-        encoder_optimizer = torch.optim.Adam(self.encoder_model.parameters(), lr=base_lr)
-        decoder_optimizer = torch.optim.Adam(self.encoder_model.parameters(), lr=base_lr)
+        optimizer = torch.optim.Adam(self.dcrnn_model.parameters(), lr=base_lr)
 
         criterion = torch.nn.L1Loss()  # mae loss
         batches_seen = 0
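The criterion above is plain MAE (torch.nn.L1Loss), but _compute_loss (added at the end of this patch) applies it to inverse-transformed tensors, so the loss is measured in the data's original units rather than in normalized space. A tiny sketch of the idea; the StandardScaler here is an assumption about what lib.utils provides (z-score semantics), and the statistics are made up:

    import torch

    class StandardScaler:
        """Assumed shape of lib.utils.StandardScaler (not shown in this patch)."""
        def __init__(self, mean, std):
            self.mean, self.std = mean, std

        def inverse_transform(self, data):
            # undo z-score normalization, i.e. go back to the original units
            return data * self.std + self.mean

    scaler = StandardScaler(mean=54.4, std=19.5)  # made-up sensor statistics
    criterion = torch.nn.L1Loss()                 # mae loss, as in the patch

    y_pred = torch.tensor([[0.1, -0.3]])          # normalized model output
    y_true = torch.tensor([[0.2, -0.2]])
    loss = criterion(scaler.inverse_transform(y_pred),
                     scaler.inverse_transform(y_true))
    print(round(loss.item(), 2))  # 1.95, i.e. 0.1 * std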
@@ -154,16 +89,23 @@
             start_time = time.time()
 
             for _, (x, y) in enumerate(train_iterator):
-                x = torch.from_numpy(x).float()
-                y = torch.from_numpy(y).float()
-                self._logger.debug("X: {}".format(x.size()))
-                self._logger.debug("y: {}".format(y.size()))
-                x = x.permute(1, 0, 2, 3)
-                y = y.permute(1, 0, 2, 3)
-                output, loss = self._train_one_batch(x, y, batches_seen, encoder_optimizer,
-                                                     decoder_optimizer, criterion)
-                losses.append(loss)
+                optimizer.zero_grad()
+
+                x, y = self._get_x_y(x, y)
+                x, y = self._get_x_y_in_correct_dims(x, y)
+
+                output = self.dcrnn_model(x, y, batches_seen)  # batches_seen drives scheduled sampling
+                loss = self._compute_loss(y, output, criterion)
+                self._logger.info(loss.item())
+                losses.append(loss.item())
+                batches_seen += 1
+                loss.backward()
+
+                # gradient clipping - this does it in place
+                torch.nn.utils.clip_grad_norm_(self.dcrnn_model.parameters(), self.max_grad_norm)
+
+                optimizer.step()
 
             end_time = time.time()
 
             if epoch_num % log_every == 0:
@@ -173,6 +115,41 @@
                                            0.0, (end_time - start_time))
                 self._logger.info(message)
 
+    def _get_x_y(self, x, y):
+        """
+        :param x: shape (batch_size, seq_len, num_sensor, input_dim)
+        :param y: shape (batch_size, horizon, num_sensor, input_dim)
+        :return: x: shape (seq_len, batch_size, num_sensor, input_dim)
+                 y: shape (horizon, batch_size, num_sensor, input_dim)
+        """
+        x = torch.from_numpy(x).float()
+        y = torch.from_numpy(y).float()
+        self._logger.debug("X: {}".format(x.size()))
+        self._logger.debug("y: {}".format(y.size()))
+        x = x.permute(1, 0, 2, 3)
+        y = y.permute(1, 0, 2, 3)
+        return x, y
+
+    def _get_x_y_in_correct_dims(self, x, y):
+        """
+        :param x: shape (seq_len, batch_size, num_sensor, input_dim)
+        :param y: shape (horizon, batch_size, num_sensor, input_dim)
+        :return: x: shape (seq_len, batch_size, num_sensor * input_dim)
+                 y: shape (horizon, batch_size, num_sensor * output_dim)
+        """
+        batch_size = x.size(1)
+        x = x.view(self.seq_len, batch_size, self.num_nodes * self.input_dim)
+        y = y[..., :self.output_dim].view(self.horizon, batch_size,
+                                          self.num_nodes * self.output_dim)
+        return x, y
+
     def _compute_sampling_threshold(self, batches_seen):
         return self.cl_decay_steps / (
                 self.cl_decay_steps + np.exp(batches_seen / self.cl_decay_steps))
+
+    def _compute_loss(self, y_true, y_predicted, criterion):
+        loss = 0
+        for t in range(self.horizon):
+            loss += criterion(self.standard_scaler.inverse_transform(y_predicted[t]),
+                              self.standard_scaler.inverse_transform(y_true[t]))
+        return loss
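For intuition on _compute_sampling_threshold: it is an inverse sigmoid decay in batches_seen, so the probability of feeding the decoder ground truth starts near 1 and falls toward 0 as training progresses. With the default cl_decay_steps of 1000:

    import numpy as np

    cl_decay_steps = 1000  # default from Seq2SeqAttrs

    def sampling_threshold(batches_seen):
        # probability of feeding ground truth to the decoder
        return cl_decay_steps / (cl_decay_steps + np.exp(batches_seen / cl_decay_steps))

    for step in (0, 5000, 7000, 10000, 15000):
        print(step, round(sampling_threshold(step), 4))
    # 0      0.999   -> almost always teacher forcing at the start
    # 5000   0.8708
    # 7000   0.477   -> roughly the crossover point
    # 10000  0.0434
    # 15000  0.0003  -> decoder runs on its own predictions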