"""Define the Layers"""
|
|
|
|
import torch.nn as nn
|
|
from model.DSANET.SubLayers import MultiHeadAttention, PositionwiseFeedForward
|
|
|
|
|
|
class EncoderLayer(nn.Module):
    """Compose with two layers: multi-head self-attention and a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(EncoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(self, enc_input):
        # Self-attention: queries, keys, and values all come from the encoder input.
        enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input)

        # Position-wise feed-forward network applied to the attention output.
        enc_output = self.pos_ffn(enc_output)

        return enc_output, enc_slf_attn


class DecoderLayer(nn.Module):
    """Compose with three layers: self-attention, encoder-decoder attention, and a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
        self,
        dec_input,
        enc_output,
        non_pad_mask=None,  # note: accepted for API compatibility but not used in this layer
        slf_attn_mask=None,
        dec_enc_attn_mask=None,
    ):
        # Masked self-attention over the decoder input.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask
        )

        # Encoder-decoder attention: the decoder output attends to the encoder output.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask
        )

        # Position-wise feed-forward network applied to the attention output.
        dec_output = self.pos_ffn(dec_output)

        return dec_output, dec_slf_attn, dec_enc_attn
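

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original DSANet source): a minimal
# smoke test for both layers. It assumes the imported SubLayers follow the
# common "Attention Is All You Need" PyTorch implementation, i.e.
# MultiHeadAttention.forward(q, k, v, mask=None) returns (output, attn) and
# both sublayers operate on tensors of shape (batch, seq_len, d_model).
if __name__ == "__main__":
    import torch

    batch, src_len, tgt_len = 2, 10, 7
    d_model, d_inner, n_head, d_k, d_v = 512, 2048, 8, 64, 64

    enc_layer = EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=0.1)
    dec_layer = DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=0.1)

    src = torch.randn(batch, src_len, d_model)  # hypothetical encoder input
    tgt = torch.randn(batch, tgt_len, d_model)  # hypothetical decoder input

    # Encoder layer: self-attention followed by the position-wise FFN.
    enc_out, enc_slf_attn = enc_layer(src)
    print(enc_out.shape)  # expected: torch.Size([2, 10, 512])

    # Decoder layer: self-attention, encoder-decoder attention, then the FFN.
    dec_out, dec_slf_attn, dec_enc_attn = dec_layer(tgt, enc_out)
    print(dec_out.shape)  # expected: torch.Size([2, 7, 512])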