# TrafficWheel/model/STID/MLP.py
# (30 lines, 953 B, Python, executable file)

import torch
from torch import nn
class MultiLayerPerceptron(nn.Module):
    """Two-layer MLP with a residual (skip) connection.

    Both layers are 1x1 ``nn.Conv2d`` convolutions, so the MLP is applied
    independently at every spatial position of a 4-D ``[B, C, H, W]`` tensor
    and mixes only the channel dimension.

    Args:
        input_dim: Number of channels of the input tensor.
        hidden_dim: Number of channels of the hidden and output tensors.
        dropout: Dropout probability applied after the first activation.
            Defaults to 0.15, the value previously hard-coded here.

    Note:
        The residual ``hidden + input_data`` requires ``input_dim`` to equal
        ``hidden_dim`` (or to broadcast against it, e.g. ``input_dim == 1``);
        otherwise ``forward`` raises a shape-mismatch error.
    """

    def __init__(self, input_dim: int, hidden_dim: int, *, dropout: float = 0.15) -> None:
        super().__init__()
        # 1x1 convolutions act as per-position linear layers over channels.
        self.fc1 = nn.Conv2d(
            in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True
        )
        self.fc2 = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True
        )
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=dropout)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """Apply the residual MLP.

        Args:
            input_data: Tensor of shape ``[B, input_dim, H, W]``, e.g.
                ``[B, D, N, 1]``. A 3-D ``[B, D, N]`` tensor is not supported
                as-is: ``nn.Conv2d`` would treat it as an unbatched
                ``[C, H, W]`` input.

        Returns:
            Tensor of shape ``[B, hidden_dim, H, W]``: the MLP output plus
            the input (residual link).
        """
        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
        # Residual link: requires matching (or broadcastable) channel counts.
        return hidden + input_data