''' Define the Layers '''
import torch.nn as nn

from model.DSANET.SubLayers import MultiHeadAttention, PositionwiseFeedForward


class EncoderLayer(nn.Module):
    ''' Compose with two layers: multi-head self-attention followed by a
    position-wise feed-forward network. '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(
            n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input):
        # Self-attention: the encoder input serves as query, key and value.
        enc_output, enc_slf_attn = self.slf_attn(
            enc_input, enc_input, enc_input)
        # Position-wise feed-forward transformation.
        enc_output = self.pos_ffn(enc_output)
        return enc_output, enc_slf_attn


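# --- Usage sketch (not part of the original DSANet source) -------------------
# A minimal smoke test for EncoderLayer. It assumes the imported sub-layers
# follow the standard Transformer convention of operating on
# (batch, seq_len, d_model) tensors; the hyperparameter values and tensor
# shapes below are illustrative only.
def _demo_encoder_layer():
    import torch
    layer = EncoderLayer(d_model=512, d_inner=2048, n_head=8, d_k=64, d_v=64, dropout=0.1)
    x = torch.randn(32, 10, 512)   # assumed (batch, seq_len, d_model) layout
    out, attn = layer(x)           # forward returns (output, self-attention weights)
    # Expected torch.Size([32, 10, 512]) if the sub-layers preserve the input shape.
    print(out.shape)

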
class DecoderLayer(nn.Module):
    ''' Compose with three layers: masked self-attention, encoder-decoder
    attention, and a position-wise feed-forward network. '''

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, dec_input, enc_output, non_pad_mask=None, slf_attn_mask=None, dec_enc_attn_mask=None):
        # Note: non_pad_mask is accepted for interface compatibility but is not
        # used in this forward pass.
        # Masked self-attention over the decoder input.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask)
        # Encoder-decoder attention: queries come from the decoder output,
        # keys and values from the encoder output.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask)
        # Position-wise feed-forward transformation.
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn
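
# --- Usage sketch (not part of the original DSANet source) -------------------
# DecoderLayer composes masked self-attention, encoder-decoder attention and a
# position-wise feed-forward block. The call below assumes the same
# (batch, seq_len, d_model) layout as above and that mask handling happens
# inside MultiHeadAttention; the optional masks are simply omitted here.
def _demo_decoder_layer():
    import torch
    layer = DecoderLayer(d_model=512, d_inner=2048, n_head=8, d_k=64, d_v=64, dropout=0.1)
    enc_output = torch.randn(32, 10, 512)  # assumed encoder output
    dec_input = torch.randn(32, 10, 512)   # assumed decoder input
    out, slf_attn, enc_attn = layer(dec_input, enc_output)
    # Expected torch.Size([32, 10, 512]) if the sub-layers preserve the input shape.
    print(out.shape)


if __name__ == "__main__":
    _demo_encoder_layer()
    _demo_decoder_layer()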