21 lines
770 B
Python
21 lines
770 B
Python
import math
|
|
from torch import nn
|
|
from torch.nn import TransformerEncoder, TransformerEncoderLayer
|
|
|
|
|
|
class TransformerLayers(nn.Module):
    """Stack of Transformer encoder layers applied to a 4-D input.

    The two leading axes of the input are folded into a single batch axis,
    the encoder attends along the third axis, and the result is unfolded back
    to the original 4-D shape.

    Args:
        hidden_dim: Size of the last (feature) axis; used as ``d_model``.
        nlayers: Number of stacked encoder layers.
        mlp_ratio: Feed-forward width as a multiple of ``hidden_dim``. The
            product is truncated to an ``int`` so float ratios are accepted.
        num_heads: Number of attention heads per layer.
        dropout: Dropout probability used inside each encoder layer.
    """

    def __init__(self, hidden_dim, nlayers, mlp_ratio, num_heads=4, dropout=0.1):
        super().__init__()
        self.d_model = hidden_dim
        # ``dim_feedforward`` must be an integer; a float ``mlp_ratio`` (e.g.
        # 4.0) would otherwise produce a float width that nn.Linear rejects.
        feedforward_dim = int(hidden_dim * mlp_ratio)
        encoder_layers = TransformerEncoderLayer(
            hidden_dim, num_heads, feedforward_dim, dropout
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

    def forward(self, src):
        """Encode ``src`` and return a tensor of the same 4-D shape.

        Args:
            src: Tensor of shape ``(B, N, L, D)``. ``D`` is presumably equal to
                ``hidden_dim`` (the encoder layers are built with that
                ``d_model``) -- verify against callers.

        Returns:
            Tensor of shape ``(B, N, L, D)``.

        Raises:
            ValueError: If ``src`` does not have exactly four dimensions
                (raised by the shape unpacking).
        """
        B, N, L, D = src.shape
        # Scale the inputs by sqrt(d_model) before they enter the encoder.
        src = src * math.sqrt(self.d_model)
        # ``reshape`` instead of ``view``: merging the first two axes of a
        # non-contiguous tensor (e.g. one produced by an upstream transpose)
        # is impossible without a copy, and ``view`` raises in that case.
        src = src.reshape(B * N, L, D)
        # The encoder is built with the default sequence-first layout, so
        # move the L axis to the front: (B*N, L, D) -> (L, B*N, D).
        src = src.transpose(0, 1)
        output = self.transformer_encoder(src, mask=None)
        # Back to batch-first, then unfold the merged batch axis.
        output = output.transpose(0, 1).reshape(B, N, L, D)
        return output
|