# 29 lines, 1.0 KiB, Python
import torch
|
|
import torch.nn as nn
|
|
|
|
class PatchEmbedding(nn.Module):
    """Cut each series into overlapping patches and embed every patch.

    The module right-pads the last axis by replication, slices it into
    windows of ``patch_len`` values taken every ``stride`` steps, projects
    each window to ``d_model`` features with a bias-free linear layer, adds
    a positional embedding, and applies dropout.

    Args:
        d_model: Output feature size of the linear patch projection; also
            passed to ``PositionalEmbedding``.
        patch_len: Number of values along the last axis in one patch.
        stride: Step between the starts of consecutive patches.
        padding: Number of replicated values appended to the right end of
            the last axis before patching.
        dropout: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, d_model, patch_len, stride, padding, dropout):
        super().__init__()

        # Patching: window length, hop size, and right-only replication pad.
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = nn.ReplicationPad1d((0, padding))

        # Backbone, input encoding: project each patch onto a
        # d_model-dimensional vector space (no bias term).
        self.value_embedding = nn.Linear(patch_len, d_model, bias=False)

        # Positional embedding.
        # NOTE(review): PositionalEmbedding is neither defined nor imported
        # in the visible part of this file -- confirm it is in scope at
        # runtime, otherwise constructing this module raises NameError.
        self.position_embedding = PositionalEmbedding(d_model)

        # Residual dropout applied to the summed embeddings.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Patch and embed ``x``.

        Args:
            x: Tensor whose axis 1 is read as the variable count and whose
                last axis is patched -- presumably
                ``(batch, n_vars, seq_len)``; TODO confirm with callers.

        Returns:
            A tuple ``(embedded, n_vars)``. ``embedded`` is the dropout
            output with the first two input axes merged into one, i.e. its
            leading size is ``x.shape[0] * x.shape[1]``, followed by the
            number of patches and the embedding features. ``n_vars`` is
            ``x.shape[1]`` of the input, returned so callers can undo the
            merge.
        """
        # Remember the variable count before the first two axes are merged.
        n_vars = x.shape[1]

        # Pad on the right, then slide a window over the last axis; unfold
        # appends a new trailing axis of length patch_len.
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)

        # Merge the first two axes so every (sample, variable) pair becomes
        # an independent sequence of patches.
        x = torch.reshape(x, (x.shape[0] * x.shape[1], x.shape[2], x.shape[3]))

        # Input encoding: the positional term is computed from the patched
        # (pre-projection) tensor and added to the projected patches.
        x = self.value_embedding(x) + self.position_embedding(x)
        return self.dropout(x), n_vars