41 lines
1.2 KiB
Python
41 lines
1.2 KiB
Python
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
class LSTM(nn.Module):
|
|
def __init__(self,
|
|
in_channels,
|
|
hidden,
|
|
out_channels,
|
|
n_layers=2,
|
|
embed_size=8,
|
|
dropout=.0):
|
|
super(LSTM, self).__init__()
|
|
self.in_channels = in_channels
|
|
self.hidden = hidden
|
|
self.embed_size = embed_size
|
|
self.out_channels = out_channels
|
|
self.n_layers = n_layers
|
|
|
|
self.encoder = nn.Embedding(in_channels, embed_size)
|
|
|
|
self.rnn =\
|
|
nn.LSTM(
|
|
input_size=embed_size if embed_size else in_channels,
|
|
hidden_size=hidden,
|
|
num_layers=n_layers,
|
|
batch_first=True,
|
|
dropout=dropout
|
|
)
|
|
|
|
self.decoder = nn.Linear(hidden, out_channels)
|
|
|
|
def forward(self, input_):
|
|
if self.embed_size:
|
|
input_ = self.encoder(input_)
|
|
output, _ = self.rnn(input_)
|
|
output = self.decoder(output)
|
|
output = output.permute(0, 2, 1) # change dimension to (B, C, T)
|
|
final_word = output[:, :, -1]
|
|
return final_word
|