import torch
from torch import nn
from timm.models.vision_transformer import trunc_normal_

from .tsformer_components.patch import PatchEmbedding
from .tsformer_components.mask import MaskGenerator
from .tsformer_components.positional_encoding import PositionalEncoding
from .tsformer_components.transformer_layers import TransformerLayers


def unshuffle(shuffled_tokens):
    """Return the inverse of the permutation ``shuffled_tokens``.

    Args:
        shuffled_tokens (list): a permutation of ``range(len(shuffled_tokens))``.

    Returns:
        list: index list ``u`` such that ``shuffled_tokens[u[i]] == i``.
    """
    # Map each value to its position in the shuffled list (inverse lookup),
    # then read the positions out in value order 0..n-1.
    position_of = {value: index for index, value in enumerate(shuffled_tokens)}
    return [position_of[i] for i in range(len(shuffled_tokens))]


class TSFormer(nn.Module):
    """An efficient unsupervised pre-training model for Time Series based on transFormer blocks. (TSFormer)"""

    def __init__(self, patch_size, in_channel, embed_dim, num_heads, mlp_ratio, dropout, num_token, mask_ratio,
                 encoder_depth, decoder_depth, mode="pre-train"):
        """
        Args:
            patch_size (int): length L of one patch (time segment).
            in_channel (int): number of input channels.
            embed_dim (int): token embedding dimension d.
            num_heads (int): number of attention heads in each Transformer layer.
            mlp_ratio (int): expansion ratio of the Transformer MLP blocks.
            dropout (float): dropout rate used by the positional encoding and Transformer layers.
            num_token (int): total number of patches P per series.
            mask_ratio (float): ratio r of patches that are masked in pre-training.
            encoder_depth (int): number of encoder Transformer layers.
            decoder_depth (int): number of decoder Transformer layers.
            mode (str): "pre-train" or "forecasting".

        Raises:
            ValueError: if ``mode`` is not a supported mode.
        """
        super().__init__()
        # raise (not assert) so validation survives `python -O`
        if mode not in ("pre-train", "forecasting"):
            raise ValueError("Error mode.")
        self.patch_size = patch_size
        self.in_channel = in_channel
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_token = num_token
        self.mask_ratio = mask_ratio
        self.encoder_depth = encoder_depth
        self.mode = mode
        self.mlp_ratio = mlp_ratio
        # channel index used to build the reconstruction ground truth
        self.selected_feature = 0

        # norm layers
        self.encoder_norm = nn.LayerNorm(embed_dim)
        self.decoder_norm = nn.LayerNorm(embed_dim)

        # encoder specifics
        # # patchify & embedding
        self.patch_embedding = PatchEmbedding(patch_size, in_channel, embed_dim, norm_layer=None)
        # # positional encoding
        self.positional_encoding = PositionalEncoding(embed_dim, dropout=dropout)
        # # masking
        self.mask = MaskGenerator(num_token, mask_ratio)
        # encoder
        self.encoder = TransformerLayers(embed_dim, encoder_depth, mlp_ratio, num_heads, dropout)

        # decoder specifics
        # # transform layer (encoder -> decoder embedding space)
        self.enc_2_dec_emb = nn.Linear(embed_dim, embed_dim, bias=True)
        # # learnable mask token, broadcast over batch/node/patch dims
        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        # # decoder
        self.decoder = TransformerLayers(embed_dim, decoder_depth, mlp_ratio, num_heads, dropout)
        # # prediction (reconstruction) layer
        self.output_layer = nn.Linear(embed_dim, patch_size)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the positional-encoding table and the mask token."""
        # positional encoding
        nn.init.uniform_(self.positional_encoding.position_embedding, -.02, .02)
        # mask token
        trunc_normal_(self.mask_token, std=.02)

    def encoding(self, long_term_history, mask=True):
        """Encoding process of TSFormer: patchify, positional encoding, mask, Transformer layers.

        Args:
            long_term_history (torch.Tensor): Very long-term historical MTS with shape [B, N, 1, P * L],
                which is used in the TSFormer.
                P is the number of segments (patches).
            mask (bool): True in pre-training stage and False in forecasting stage.

        Returns:
            torch.Tensor: hidden states of unmasked tokens, shape [B, N, P*(1-r), d] (or [B, N, P, d] if not masked)
            list: unmasked token index (None when ``mask`` is False)
            list: masked token index (None when ``mask`` is False)
        """
        batch_size, num_nodes, _, _ = long_term_history.shape
        # patchify and embed input
        patches = self.patch_embedding(long_term_history)     # B, N, d, P
        patches = patches.transpose(-1, -2)                   # B, N, P, d
        # positional embedding
        patches = self.positional_encoding(patches)

        # mask: keep only the unmasked patches for the encoder
        if mask:
            unmasked_token_index, masked_token_index = self.mask()
            encoder_input = patches[:, :, unmasked_token_index, :]
        else:
            unmasked_token_index, masked_token_index = None, None
            encoder_input = patches

        # encoding
        hidden_states_unmasked = self.encoder(encoder_input)
        hidden_states_unmasked = self.encoder_norm(hidden_states_unmasked).view(batch_size, num_nodes, -1, self.embed_dim)

        return hidden_states_unmasked, unmasked_token_index, masked_token_index

    def decoding(self, hidden_states_unmasked, masked_token_index):
        """Decoding process of TSFormer: encoder-to-decoder layer, add mask tokens, Transformer layers, predict.

        Args:
            hidden_states_unmasked (torch.Tensor): hidden states of *unmasked* tokens, [B, N, P*(1-r), d].
            masked_token_index (list): masked token index.

        Returns:
            torch.Tensor: reconstructed data, [B, N, P, L].
        """
        batch_size, num_nodes, _, _ = hidden_states_unmasked.shape

        # encoder-to-decoder projection
        hidden_states_unmasked = self.enc_2_dec_emb(hidden_states_unmasked)
        # expand the learnable mask token over the masked positions and add
        # their positional encodings (selected via ``index``)
        hidden_states_masked = self.positional_encoding(
            self.mask_token.expand(batch_size, num_nodes, len(masked_token_index), hidden_states_unmasked.shape[-1]),
            index=masked_token_index
        )
        # unmasked tokens first, masked tokens appended after them
        hidden_states_full = torch.cat([hidden_states_unmasked, hidden_states_masked], dim=-2)   # B, N, P, d

        # decoding
        hidden_states_full = self.decoder(hidden_states_full)
        hidden_states_full = self.decoder_norm(hidden_states_full)

        # prediction (reconstruction): project each token back to a patch of length L
        reconstruction_full = self.output_layer(hidden_states_full.view(batch_size, num_nodes, -1, self.embed_dim))

        return reconstruction_full

    def get_reconstructed_masked_tokens(self, reconstruction_full, real_value_full, unmasked_token_index, masked_token_index):
        """Get reconstructed masked tokens and corresponding ground-truth for subsequent loss computing.

        Args:
            reconstruction_full (torch.Tensor): reconstructed full tokens, [B, N, P, L].
            real_value_full (torch.Tensor): ground truth full tokens, [B, N, 1, P * L].
            unmasked_token_index (list): unmasked token index.
            masked_token_index (list): masked token index.

        Returns:
            torch.Tensor: reconstructed masked tokens, [B, r*P*L, N].
            torch.Tensor: ground truth masked tokens, [B, r*P*L, N].
        """
        # get reconstructed masked tokens: the decoder placed them after the
        # unmasked tokens, so they occupy the tail of the token dimension
        batch_size, num_nodes, _, _ = reconstruction_full.shape
        reconstruction_masked_tokens = reconstruction_full[:, :, len(unmasked_token_index):, :]     # B, N, r*P, L
        reconstruction_masked_tokens = reconstruction_masked_tokens.view(batch_size, num_nodes, -1).transpose(1, 2)     # B, r*P*L, N

        # re-patchify the raw series and keep only the selected feature channel
        label_full = real_value_full.permute(0, 3, 1, 2).unfold(1, self.patch_size, self.patch_size)[:, :, :, self.selected_feature, :].transpose(1, 2)  # B, N, P, L
        label_masked_tokens = label_full[:, :, masked_token_index, :].contiguous()     # B, N, r*P, L
        label_masked_tokens = label_masked_tokens.view(batch_size, num_nodes, -1).transpose(1, 2)  # B, r*P*L, N

        return reconstruction_masked_tokens, label_masked_tokens

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor = None, batch_seen: int = None, epoch: int = None, **kwargs) -> torch.Tensor:
        """Feed forward of the TSFormer.

        TSFormer has two modes: the pre-training mode and the forecasting mode,
        which are used in the pre-training stage and the forecasting stage, respectively.

        Args:
            history_data (torch.Tensor): very long-term historical time series with shape B, L * P, N, 1.
            future_data (torch.Tensor): unused; kept for a uniform runner interface.
            batch_seen (int): unused; kept for a uniform runner interface.
            epoch (int): unused; kept for a uniform runner interface.

        Returns:
            pre-training:
                torch.Tensor: the reconstruction of the masked tokens, [B, r*P*L, N].
                torch.Tensor: the ground truth of the masked tokens, [B, r*P*L, N].
            forecasting:
                torch.Tensor: full-token encoder hidden states, [B, N, P, d].
        """
        # reshape to the [B, N, C, T] layout the encoder expects
        history_data = history_data.permute(0, 2, 3, 1)     # B, N, 1, L * P
        # feed forward
        if self.mode == "pre-train":
            # encoding
            hidden_states_unmasked, unmasked_token_index, masked_token_index = self.encoding(history_data)
            # decoding
            reconstruction_full = self.decoding(hidden_states_unmasked, masked_token_index)
            # for subsequent loss computing
            reconstruction_masked_tokens, label_masked_tokens = self.get_reconstructed_masked_tokens(reconstruction_full, history_data, unmasked_token_index, masked_token_index)
            return reconstruction_masked_tokens, label_masked_tokens
        else:
            hidden_states_full, _, _ = self.encoding(history_data, mask=False)
            return hidden_states_full