Setup curriculum learning framework
This commit is contained in:
parent
bdce241a8f
commit
a1c9af2bad
|
|
@ -1,3 +1,4 @@
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
@ -122,6 +123,20 @@ class DecoderModel(nn.Module, DCRNNModel):
|
||||||
bias=True) for _ in
|
bias=True) for _ in
|
||||||
range(self.num_rnn_layers - 1)]
|
range(self.num_rnn_layers - 1)]
|
||||||
|
|
||||||
|
def t_step_forward_pass(self, hidden_state, inputs, output, t):
|
||||||
|
cell_input = inputs[:, t, :] # (batch_size, input_size)
|
||||||
|
|
||||||
|
if self.is_training:
|
||||||
|
if t > 0 and self.use_curriculum_learning:
|
||||||
|
c = np.random.uniform(0, 1)
|
||||||
|
if c >= self._compute_sampling_threshold(): #todo
|
||||||
|
cell_input = output[
|
||||||
|
t - 1] # todo: this won't work because the linear layer is applied after forward_impl
|
||||||
|
|
||||||
|
cell_output, hidden_state = self._forward_cell(cell_input, hidden_state)
|
||||||
|
output[t] = cell_output
|
||||||
|
return hidden_state
|
||||||
|
|
||||||
def forward(self, inputs, hidden_state=None):
|
def forward(self, inputs, hidden_state=None):
|
||||||
"""
|
"""
|
||||||
Decoder forward pass.
|
Decoder forward pass.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue