# Source: TrafficWheel/model/STEP/tsformer_components/patch.py
# (43 lines, 1.8 KiB, Python)

from torch import nn
class PatchEmbedding(nn.Module):
    """Patchify time series.

    Splits each node's time series into non-overlapping patches of
    ``patch_size`` steps and projects every patch to ``embed_dim`` channels
    with a strided ``Conv2d`` whose kernel spans exactly one patch.
    """

    def __init__(self, patch_size, in_channel, embed_dim, norm_layer):
        """
        Args:
            patch_size (int): number of time steps per patch (the "L" in the
                forward docstring); used as both kernel height and stride.
            in_channel (int): number of input features C per time step.
            embed_dim (int): output embedding size d per patch.
            norm_layer (nn.Module or None): module applied to the conv output,
                which has shape [B*N, d, P, 1]. ``None`` means no normalization
                (``nn.Identity``).
        """
        super().__init__()
        self.len_patch = patch_size  # the L
        self.input_channel = in_channel
        self.output_channel = embed_dim
        # Kernel height == stride == patch length, so each output position
        # sees exactly one non-overlapping patch along the time axis.
        self.input_embedding = nn.Conv2d(
            in_channel,
            embed_dim,
            kernel_size=(self.len_patch, 1),
            stride=(self.len_patch, 1),
        )
        self.norm_layer = norm_layer if norm_layer is not None else nn.Identity()

    def forward(self, long_term_history):
        """
        Args:
            long_term_history (torch.Tensor): Very long-term historical MTS with
                shape [B, N, C, P * L] (C is typically 1), which is used in the
                TSFormer. P is the number of segments (patches) and L is the
                patch length (``patch_size``).

        Returns:
            torch.Tensor: patchified time series with shape [B, N, d, P].

        Raises:
            ValueError: if the time length is not a multiple of ``patch_size``
                (the strided conv would otherwise silently drop the tail).
        """
        batch_size, num_nodes, num_feat, len_time_series = long_term_history.shape
        if len_time_series % self.len_patch != 0:
            raise ValueError(
                f"time series length {len_time_series} is not divisible by "
                f"patch size {self.len_patch}"
            )
        num_patches = len_time_series // self.len_patch

        # [B, N, C, P*L] -> [B*N, C, P*L, 1]: fold nodes into the batch and add
        # a trailing width-1 axis so Conv2d can slide along time only.
        long_term_history = long_term_history.reshape(
            batch_size * num_nodes, num_feat, len_time_series, 1
        )
        # [B*N, d, P, 1]
        output = self.input_embedding(long_term_history)
        output = self.norm_layer(output)
        # [B*N, d, P] -> [B, N, d, P]
        output = output.squeeze(-1).reshape(
            batch_size, num_nodes, self.output_channel, num_patches
        )
        return output