41 lines
1.3 KiB
Python
41 lines
1.3 KiB
Python
import torch
|
|
import torch.nn.functional as F
|
|
from torch.nn import Linear, ModuleList
|
|
from torch.nn import BatchNorm1d, Identity
|
|
|
|
|
|
class MLP(torch.nn.Module):
    """
    Multilayer Perceptron

    Builds one ``Linear`` layer for every consecutive pair of sizes in
    ``channel_list``. Between two linear layers the activation passes
    through a normalization layer, a ReLU and dropout; the output of the
    last linear layer is returned without normalization, activation or
    dropout.

    Args:
        channel_list: Sequence of layer sizes. Entry ``i`` is the input size
            and entry ``i + 1`` the output size of the ``i``-th linear
            layer. Must contain at least two entries.
        dropout: Probability handed to ``F.dropout`` after each hidden
            activation. Dropout is only active while the module is in
            training mode.
        batch_norm: If true, each linear layer gets a ``BatchNorm1d`` over
            its output size; otherwise an ``Identity`` placeholder is used.
        relu_first: If true, ReLU is applied before the normalization layer;
            otherwise it is applied after it.
    """

    def __init__(self,
                 channel_list,
                 dropout=0.,
                 batch_norm=True,
                 relu_first=False):
        super().__init__()
        assert len(channel_list) >= 2, (
            "channel_list needs at least an input and an output size")
        self.channel_list = channel_list
        self.dropout = dropout
        self.relu_first = relu_first

        self.linears = ModuleList()
        self.norms = ModuleList()
        for in_channel, out_channel in zip(channel_list[:-1],
                                           channel_list[1:]):
            self.linears.append(Linear(in_channel, out_channel))
            # One norm per linear layer. The norm paired with the final
            # linear layer is never used by ``forward``; it is still
            # registered so the module's parameter names stay unchanged.
            self.norms.append(
                BatchNorm1d(out_channel) if batch_norm else Identity())

    def forward(self, x):
        """
        Run ``x`` through all layers and return the last linear output.

        Args:
            x: Input tensor whose feature size matches ``channel_list[0]``
                (presumably shaped ``(batch, channel_list[0])`` -- confirm
                against callers).

        Returns:
            The output of the final linear layer.
        """
        x = self.linears[0](x)
        # Pair every following linear layer with the norm that belongs to
        # the layer *before* it; the last norm is therefore skipped.
        for layer, norm in zip(self.linears[1:], self.norms[:-1]):
            if self.relu_first:
                x = norm(F.relu(x))
            else:
                x = F.relu(norm(x))
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Call the module itself rather than ``layer.forward`` so that
            # hooks registered on the layer run, as they do for the first
            # linear layer above.
            x = layer(x)
        return x
|