TrafficWheel/model/D2STGNN/decouple/residual_decomp.py

16 lines
329 B
Python

import torch.nn as nn
class ResidualDecomp(nn.Module):
    """Residual decomposition.

    Subtracts the ReLU-activated tensor ``y`` from ``x`` and applies layer
    normalization over the last dimension of the result.
    """

    def __init__(self, input_shape):
        """Build the normalization and activation sub-modules.

        Args:
            input_shape: Indexable sequence of ints; only its last entry is
                read, and it is used as the normalized (trailing) feature
                size of the layer norm.
        """
        super().__init__()
        # Attribute names are kept as-is so parameter keys in existing
        # checkpoints (``ln.weight`` / ``ln.bias``) continue to match.
        self.ln = nn.LayerNorm(input_shape[-1])
        self.ac = nn.ReLU()

    def forward(self, x, y):
        """Return ``LayerNorm(x - ReLU(y))``.

        Args:
            x: Tensor the activated ``y`` is subtracted from.
            y: Tensor passed through ReLU before subtraction; must be
                broadcast-compatible with ``x``.

        Returns:
            The normalized residual. The last dimension of ``x - ReLU(y)``
            must equal ``input_shape[-1]`` given at construction.
        """
        return self.ln(x - self.ac(y))