From 6386ac7eb4f4dede91eb570c7cfe6bc62e0837c0 Mon Sep 17 00:00:00 2001 From: Chintan Shah Date: Sun, 29 Sep 2019 17:40:52 -0400 Subject: [PATCH] Implemented encoder using GRUCell instead so that it's easier to swap that with DCGRUCell --- model/pytorch/dcrnn_model.py | 46 ++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 0403a14..bc36567 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -30,12 +30,44 @@ class EncoderModel(DCRNNModel): # https://pytorch.org/docs/stable/nn.html#gru - # since input shape is Input (batch_size, timesteps, num_sensor*input_dim),batch_first=True - self.dcgru = nn.GRU(input_size=self.num_nodes * self.input_dim, - hidden_size=self.rnn_units, - num_layers=self.num_rnn_layers, - batch_first=True) + # input shape is supposed to be Input (batch_size, timesteps, num_sensor*input_dim) + # first layer takes input shape and subsequent layer take input from the first layer + self.dcgru_layers = [nn.GRUCell(input_size=self.num_nodes * self.input_dim, + hidden_size=self.rnn_units, + bias=True)] + [nn.GRUCell(input_size=self.rnn_units, + hidden_size=self.rnn_units, + bias=True) for _ in + range(self.num_rnn_layers - 1)] def forward(self, inputs, hidden_state=None): - # is None okay? - return self.dcgru(inputs, hidden_state) + """ + Encoder forward pass. + + :param inputs: shape (batch_size, timesteps, num_sensor*input_dim) + :param hidden_state: (num_layers, batch_size, rnn_units) -> optional, zeros if not provided + :return: output, hidden_state + """ + layer_input = inputs.permute(1, 0, 2) # first axis is now timesteps + if hidden_state is None: + batch_size = inputs.size()[0] + hidden_state = torch.zeros((self.num_rnn_layers, batch_size, self.rnn_units), + device=device) + hidden = torch.empty_like(hidden_state) + for layer_num, dcgru_layer in enumerate(self.dcgru_layers): + layer_states = self._forward_layer(layer_input, dcgru_layer, hidden_state[layer_num]) + # append last time step's hidden state + hidden[layer_num] = layer_states[-1] + layer_input = layer_states + + output = layer_input # last layer's output + return output, hidden + + @staticmethod + def _forward_layer(inputs, dcgru_layer, hidden_state): + # inputs shape = (timesteps, batch_size, input_size) + outputs = [] # shape (timesteps, batch_size, self.rnn_units) + for cell_input in inputs[:, ]: + hidden_state = dcgru_layer(cell_input, hidden_state) + outputs.append(hidden_state) + + return torch.cat(outputs, dim=1) # runs in O(timesteps) not too slow