diff --git a/model/pytorch/dcrnn_model.py b/model/pytorch/dcrnn_model.py index 560f7ca..7ac12df 100644 --- a/model/pytorch/dcrnn_model.py +++ b/model/pytorch/dcrnn_model.py @@ -153,7 +153,7 @@ class DCRNNModel(nn.Module, Seq2SeqAttrs): seq2seq forward pass :param inputs: shape (seq_len, batch_size, num_sensor * input_dim) :param labels: shape (horizon, batch_size, num_sensor * output) - :param batches_seen: batches seen till date + :param batches_seen: batches seen till now :return: output: (self.horizon, batch_size, self.num_nodes * self.output_dim) """ encoder_hidden_state = self.encoder(inputs)