36 lines
1.2 KiB
Python
36 lines
1.2 KiB
Python
import torch
|
|
from torch import nn
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Learnable positional encoding added along the patch axis.

    A table of ``max_len`` rows, each of width ``hidden_dim``, is stored as
    a trainable parameter. ``forward`` adds rows of that table to the input
    and then applies dropout.
    """

    def __init__(self, hidden_dim: int, dropout: float = 0.1, max_len: int = 1000):
        """Create the dropout layer and the positional embedding table.

        Args:
            hidden_dim: width of every positional embedding vector; must
                match the last dimension of the tensors given to ``forward``.
            dropout: dropout probability applied after the embedding is added.
            max_len: number of positions stored in the embedding table.
        """
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        self.position_embedding = nn.Parameter(torch.empty(max_len, hidden_dim), requires_grad=True)
        # torch.empty returns uninitialised memory, which may hold NaN/inf.
        # Give the table small, finite starting values so the module is
        # usable on its own; this is harmless if an enclosing model
        # re-initialises the parameter afterwards.
        nn.init.uniform_(self.position_embedding, -0.02, 0.02)

    def forward(self, input_data, index=None, abs_idx=None):
        """Add positional embeddings to the input, then apply dropout.

        Args:
            input_data (torch.Tensor): input sequence with shape [B, N, P, d].
            index (list or None): positions whose embeddings are added, one
                per patch. When None, the first P rows of the table are used.
            abs_idx: unused; accepted only to keep the call signature
                backward-compatible.

        Returns:
            torch.Tensor: output sequence with the same shape as the input.
        """
        batch_size, num_nodes, num_patches, num_feat = input_data.shape
        # Merge batch and node axes so the embedding broadcasts over both.
        # reshape (not view) so non-contiguous inputs, e.g. the result of a
        # transpose/permute, are accepted as well.
        input_data = input_data.reshape(batch_size * num_nodes, num_patches, num_feat)
        # positional encoding
        if index is None:
            pe = self.position_embedding[:num_patches, :].unsqueeze(0)
        else:
            pe = self.position_embedding[index].unsqueeze(0)
        input_data = input_data + pe
        input_data = self.dropout(input_data)
        # restore the original [B, N, P, d] layout
        input_data = input_data.reshape(batch_size, num_nodes, num_patches, num_feat)
        return input_data
|