import torch
from torch import nn


class MultiLayerPerceptron(nn.Module):
    """Multi-Layer Perceptron with residual links."""

    def __init__(self, input_dim, hidden_dim) -> None:
        super().__init__()
        # 1x1 convolutions act as position-wise linear layers over the channel dimension.
        self.fc1 = nn.Conv2d(
            in_channels=input_dim,
            out_channels=hidden_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        self.fc2 = nn.Conv2d(
            in_channels=hidden_dim,
            out_channels=hidden_dim,
            kernel_size=(1, 1),
            bias=True,
        )
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.15)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Feed forward of MLP.

        Args:
            input_data (torch.Tensor): input data with shape [B, D, N, 1]
                (4D, as required by nn.Conv2d with 1x1 kernels)

        Returns:
            torch.Tensor: latent representation with the same shape as the input
        """
        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
        # Residual connection; requires hidden_dim == input_dim so shapes match.
        hidden = hidden + input_data
        return hidden
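

# Minimal usage sketch (not part of the original module): assumes the input
# follows the docstring shape [B, D, N, 1] and that hidden_dim equals input_dim
# so the residual addition is valid. The sizes below are hypothetical.
if __name__ == "__main__":
    batch_size, input_dim, num_nodes = 4, 32, 207  # hypothetical sizes
    mlp = MultiLayerPerceptron(input_dim=input_dim, hidden_dim=input_dim)
    x = torch.randn(batch_size, input_dim, num_nodes, 1)  # [B, D, N, 1]
    out = mlp(x)
    assert out.shape == x.shape  # residual link preserves the input shape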