TrafficWheel/model/DSANET/Layers.py

"""Define the Layers"""
import torch.nn as nn
from model.DSANET.SubLayers import MultiHeadAttention, PositionwiseFeedForward
class EncoderLayer(nn.Module):
"""Compose with two layers"""
def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
super(EncoderLayer, self).__init__()
self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)
def forward(self, enc_input):
enc_output, enc_slf_attn = self.slf_attn(enc_input, enc_input, enc_input)
enc_output = self.pos_ffn(enc_output)
return enc_output, enc_slf_attn
class DecoderLayer(nn.Module):
    """Compose with three sub-layers: self-attention, encoder-decoder attention and a position-wise feed-forward network."""

    def __init__(self, d_model, d_inner, n_head, d_k, d_v, dropout=0.1):
        super(DecoderLayer, self).__init__()
        self.slf_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.enc_attn = MultiHeadAttention(n_head, d_model, d_k, d_v, dropout=dropout)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_inner, dropout=dropout)

    def forward(
        self,
        dec_input,
        enc_output,
        non_pad_mask=None,  # note: not used in this forward pass
        slf_attn_mask=None,
        dec_enc_attn_mask=None,
    ):
        # Masked self-attention over the decoder input.
        dec_output, dec_slf_attn = self.slf_attn(
            dec_input, dec_input, dec_input, mask=slf_attn_mask
        )
        # Encoder-decoder attention: queries come from the decoder, keys/values from the encoder output.
        dec_output, dec_enc_attn = self.enc_attn(
            dec_output, enc_output, enc_output, mask=dec_enc_attn_mask
        )
        # Position-wise feed-forward network.
        dec_output = self.pos_ffn(dec_output)
        return dec_output, dec_slf_attn, dec_enc_attn
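

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file): exercises both layers with
    # random tensors, assuming the SubLayers follow the standard multi-head attention
    # interface with inputs of shape (batch, length, d_model). All dimensions below are
    # illustrative values, not settings taken from DSANET.
    import torch

    d_model, d_inner, n_head, d_k, d_v = 512, 2048, 8, 64, 64
    enc_layer = EncoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=0.1)
    dec_layer = DecoderLayer(d_model, d_inner, n_head, d_k, d_v, dropout=0.1)

    src = torch.randn(32, 10, d_model)  # (batch, source length, model dim)
    tgt = torch.randn(32, 7, d_model)   # (batch, target length, model dim)

    enc_output, enc_slf_attn = enc_layer(src)
    dec_output, dec_slf_attn, dec_enc_attn = dec_layer(tgt, enc_output)
    print(enc_output.shape, dec_output.shape)  # expected: (32, 10, 512) and (32, 7, 512)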