import torch.nn as nn


class ResidualDecomp(nn.Module):
    """Residual decomposition.

    Removes the rectified component ``ReLU(y)`` from ``x`` and layer-normalizes
    what remains over the last dimension.
    """

    def __init__(self, input_shape):
        """Build the normalization and activation sub-modules.

        Args:
            input_shape: Sequence whose last entry is the feature size that
                ``nn.LayerNorm`` normalizes over; the other entries are unused.
        """
        super().__init__()
        # The attribute names ``ln`` / ``ac`` are kept unchanged because they
        # determine the state_dict keys (``ln.weight``, ``ln.bias``).
        feature_dim = input_shape[-1]
        self.ln = nn.LayerNorm(feature_dim)
        self.ac = nn.ReLU()

    def forward(self, x, y):
        """Return ``LayerNorm(x - ReLU(y))``.

        Args:
            x: Tensor to take the residual from.
            y: Tensor broadcast-compatible with ``x``; only its positive part
                is subtracted.

        Returns:
            The normalized residual, shaped like ``x - ReLU(y)``. Its last
            dimension must equal ``input_shape[-1]`` for the LayerNorm to apply.
        """
        rectified = self.ac(y)
        return self.ln(x - rectified)