# TrafficWheel/model/STEP/tsformer_components/transformer_layers.py
# (21 lines, 770 B, Python)

import math
from torch import nn
from torch.nn import TransformerEncoder, TransformerEncoderLayer
class TransformerLayers(nn.Module):
    """Stack of standard Transformer encoder layers applied per sequence.

    The input is a 4-D tensor ``(B, N, L, D)``. The ``B * N`` leading entries
    are treated as independent sequences of ``L`` tokens each, so
    self-attention mixes information along ``L`` only. Presumably ``N`` is the
    number of series/nodes and ``L`` the number of tokens (e.g. patches) per
    series -- TODO confirm against the caller.

    Args:
        hidden_dim: Token embedding size ``D``; used as ``d_model`` of every
            encoder layer. Must be divisible by ``num_heads`` (an
            ``nn.MultiheadAttention`` requirement).
        nlayers: Number of stacked encoder layers.
        mlp_ratio: Width multiplier of the feed-forward sub-layer. The hidden
            width is ``int(hidden_dim * mlp_ratio)``, so float ratios such as
            ``1.5`` are accepted as well as integers.
        num_heads: Number of attention heads per encoder layer.
        dropout: Dropout probability passed to each encoder layer.
    """

    def __init__(self, hidden_dim, nlayers, mlp_ratio, num_heads=4, dropout=0.1):
        super().__init__()
        self.d_model = hidden_dim
        # nn.Linear needs an integer width: int() leaves integer ratios
        # unchanged and makes fractional ratios usable.
        dim_feedforward = int(hidden_dim * mlp_ratio)
        encoder_layers = TransformerEncoderLayer(
            hidden_dim, num_heads, dim_feedforward, dropout
        )
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)

    def forward(self, src):
        """Encode every length-``L`` sequence in ``src`` independently.

        Args:
            src: Tensor of shape ``(B, N, L, D)`` with ``D == hidden_dim``.

        Returns:
            Tensor of the same shape ``(B, N, L, D)``.
        """
        B, N, L, D = src.shape
        # Inputs are scaled by sqrt(hidden_dim) before entering the encoder.
        src = src * math.sqrt(self.d_model)
        # reshape (not view): still a zero-copy view for contiguous input, but
        # also works when src is non-contiguous (e.g. produced by permute).
        src = src.reshape(B * N, L, D)
        # The encoder uses the default batch_first=False, so it expects
        # (sequence, batch, feature) = (L, B*N, D).
        src = src.transpose(0, 1)
        output = self.transformer_encoder(src, mask=None)
        return output.transpose(0, 1).reshape(B, N, L, D)