from torch import nn


class PatchEmbedding(nn.Module):
    """Patchify time series into non-overlapping patches of length L."""

    def __init__(self, patch_size, in_channel, embed_dim, norm_layer):
        super().__init__()
        self.len_patch = patch_size  # the patch length L
        self.input_channel = in_channel
        self.output_channel = embed_dim
        # A Conv2d whose kernel and stride both equal the patch length acts as
        # a non-overlapping patchifier: each length-L patch is projected to a
        # d-dimensional embedding in a single pass.
        self.input_embedding = nn.Conv2d(
            in_channel,
            embed_dim,
            kernel_size=(self.len_patch, 1),
            stride=(self.len_patch, 1))
        self.norm_layer = norm_layer if norm_layer is not None else nn.Identity()

    def forward(self, long_term_history):
        """Patchify the input time series.

        Args:
            long_term_history (torch.Tensor): Very long-term historical MTS with
                shape [B, N, 1, P * L], which is used in the TSFormer.
                P is the number of segments (patches) and L is the patch length.

        Returns:
            torch.Tensor: Patchified time series with shape [B, N, d, P].
        """
        batch_size, num_nodes, num_feat, len_time_series = long_term_history.shape
        long_term_history = long_term_history.unsqueeze(-1)  # B, N, C, P*L, 1
        # B*N, C, P*L, 1
        long_term_history = long_term_history.reshape(batch_size * num_nodes, num_feat, len_time_series, 1)
        # B*N, d, P, 1: the conv collapses each length-L patch into one step
        output = self.input_embedding(long_term_history)
        # norm
        output = self.norm_layer(output)
        # reshape back to B, N, d, P
        output = output.squeeze(-1).view(batch_size, num_nodes, self.output_channel, -1)
        assert output.shape[-1] == len_time_series // self.len_patch
        return output
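
# --- Usage sketch (illustrative; the concrete sizes below are assumptions,
# not values taken from the original code) ---
# With patch length L=12 and an input of shape [B, N, 1, P * L] where P=8,
# the module should return embeddings of shape [B, N, embed_dim, P].
if __name__ == "__main__":
    import torch

    patch_embedding = PatchEmbedding(patch_size=12, in_channel=1, embed_dim=96, norm_layer=None)
    history = torch.randn(4, 207, 1, 8 * 12)  # B=4, N=207, C=1, P*L=96
    patches = patch_embedding(history)
    print(patches.shape)  # torch.Size([4, 207, 96, 8]) -> [B, N, d, P]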