# TrafficWheel/model/STEP/tsformer.py
import torch
from torch import nn
from timm.models.vision_transformer import trunc_normal_
from .tsformer_components.patch import PatchEmbedding
from .tsformer_components.mask import MaskGenerator
from .tsformer_components.positional_encoding import PositionalEncoding
from .tsformer_components.transformer_layers import TransformerLayers
def unshuffle(shuffled_tokens):
    """Return the inverse permutation of *shuffled_tokens*.

    Given a permutation (e.g. the shuffled token order produced by masking),
    returns the index list that maps each original position back to its
    location in the shuffled sequence.
    """
    position_of = {token: position for position, token in enumerate(shuffled_tokens)}
    return [position_of[original] for original in range(len(shuffled_tokens))]
class TSFormer(nn.Module):
    """An efficient unsupervised pre-training model for Time Series based on transFormer blocks. (TSFormer)

    Two modes (fixed at construction time):
      * "pre-train":   masked-autoencoder style reconstruction of masked patches.
      * "forecasting": encode the full (unmasked) series and return hidden states.
    """

    def __init__(self, patch_size, in_channel, embed_dim, num_heads, mlp_ratio, dropout, num_token, mask_ratio, encoder_depth, decoder_depth, mode="pre-train"):
        """Build the encoder/decoder stacks and all embedding layers.

        Args:
            patch_size (int): length L of each time-series patch (also the
                per-patch reconstruction output size).
            in_channel (int): number of input channels fed to patch embedding.
            embed_dim (int): token embedding dimension d.
            num_heads (int): attention heads in each Transformer layer.
            mlp_ratio: MLP expansion ratio inside the Transformer layers.
            dropout (float): dropout rate for positional encoding / Transformer.
            num_token (int): total number of patches P per series.
            mask_ratio (float): fraction r of tokens masked in pre-training.
            encoder_depth (int): number of encoder Transformer layers.
            decoder_depth (int): number of decoder Transformer layers.
            mode (str): "pre-train" or "forecasting".
        """
        super().__init__()
        assert mode in ["pre-train", "forecasting"], "Error mode."
        self.patch_size = patch_size
        self.in_channel = in_channel
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_token = num_token
        self.mask_ratio = mask_ratio
        self.encoder_depth = encoder_depth
        self.mode = mode
        self.mlp_ratio = mlp_ratio
        # index of the input channel used as the reconstruction ground truth
        self.selected_feature = 0
        # norm layers
        self.encoder_norm = nn.LayerNorm(embed_dim)
        self.decoder_norm = nn.LayerNorm(embed_dim)
        # encoder specifics
        # # patchify & embedding
        self.patch_embedding = PatchEmbedding(patch_size, in_channel, embed_dim, norm_layer=None)
        # # positional encoding (shared by encoder input and decoder mask tokens)
        self.positional_encoding = PositionalEncoding(embed_dim, dropout=dropout)
        # # masking: yields (unmasked_index, masked_index) partitions of the P tokens
        self.mask = MaskGenerator(num_token, mask_ratio)
        # encoder
        self.encoder = TransformerLayers(embed_dim, encoder_depth, mlp_ratio, num_heads, dropout)
        # decoder specifics
        # transform layer bridging encoder output into the decoder's space
        self.enc_2_dec_emb = nn.Linear(embed_dim, embed_dim, bias=True)
        # # learnable mask token, broadcast over batch / node / masked positions
        self.mask_token = nn.Parameter(torch.zeros(1, 1, 1, embed_dim))
        # # decoder
        self.decoder = TransformerLayers(embed_dim, decoder_depth, mlp_ratio, num_heads, dropout)
        # # prediction (reconstruction) layer: one patch of raw values per token
        self.output_layer = nn.Linear(embed_dim, patch_size)
        self.initialize_weights()

    def initialize_weights(self):
        """Initialize the learnable positional embedding and the mask token."""
        # positional encoding
        nn.init.uniform_(self.positional_encoding.position_embedding, -.02, .02)
        # mask token: truncated normal, as in MAE-style pre-training
        trunc_normal_(self.mask_token, std=.02)

    def encoding(self, long_term_history, mask=True):
        """Encoding process of TSFormer: patchify, positional encoding, mask, Transformer layers.

        Args:
            long_term_history (torch.Tensor): Very long-term historical MTS with shape [B, N, 1, P * L],
                which is used in the TSFormer.
                P is the number of segments (patches), L the patch length.
            mask (bool): True in pre-training stage and False in forecasting stage.

        Returns:
            torch.Tensor: hidden states of unmasked tokens, [B, N, P', d]
                (P' = P * (1 - r) when masking, otherwise P).
            list | None: unmasked token index (None when mask=False).
            list | None: masked token index (None when mask=False).
        """
        batch_size, num_nodes, _, _ = long_term_history.shape
        # patchify and embed input
        # NOTE(review): assumes PatchEmbedding returns [B, N, d, P] — confirm in patch.py
        patches = self.patch_embedding(long_term_history)          # B, N, d, P
        patches = patches.transpose(-1, -2)                        # B, N, P, d
        # positional embedding
        patches = self.positional_encoding(patches)
        # mask
        if mask:
            unmasked_token_index, masked_token_index = self.mask()
            # keep only the visible tokens; the encoder never sees masked ones
            encoder_input = patches[:, :, unmasked_token_index, :]
        else:
            unmasked_token_index, masked_token_index = None, None
            encoder_input = patches
        # encoding
        hidden_states_unmasked = self.encoder(encoder_input)
        hidden_states_unmasked = self.encoder_norm(hidden_states_unmasked).view(batch_size, num_nodes, -1, self.embed_dim)
        return hidden_states_unmasked, unmasked_token_index, masked_token_index

    def decoding(self, hidden_states_unmasked, masked_token_index):
        """Decoding process of TSFormer: encoder 2 decoder layer, add mask tokens, Transformer layers, predict.

        Args:
            hidden_states_unmasked (torch.Tensor): hidden states of *unmasked* tokens [B, N, P*(1-r), d].
            masked_token_index (list): masked token index.

        Returns:
            torch.Tensor: reconstructed data, [B, N, P, patch_size]; token order is
                [unmasked tokens ..., masked tokens ...] (NOT original order) —
                `get_reconstructed_masked_tokens` relies on this ordering.
        """
        batch_size, num_nodes, _, _ = hidden_states_unmasked.shape
        # encoder 2 decoder layer
        hidden_states_unmasked = self.enc_2_dec_emb(hidden_states_unmasked)
        # add mask tokens: expand the learnable token over batch/nodes/masked slots,
        # then inject positional information for the *masked* positions only
        # NOTE(review): assumes PositionalEncoding supports an `index` argument that
        # selects per-position embeddings — confirm in positional_encoding.py
        hidden_states_masked = self.positional_encoding(
            self.mask_token.expand(batch_size, num_nodes, len(masked_token_index), hidden_states_unmasked.shape[-1]),
            index=masked_token_index
        )
        # concatenate along the token axis: visible tokens first, mask tokens last
        hidden_states_full = torch.cat([hidden_states_unmasked, hidden_states_masked], dim=-2)   # B, N, P, d
        # decoding
        hidden_states_full = self.decoder(hidden_states_full)
        hidden_states_full = self.decoder_norm(hidden_states_full)
        # prediction (reconstruction): project each token back to a raw patch
        reconstruction_full = self.output_layer(hidden_states_full.view(batch_size, num_nodes, -1, self.embed_dim))
        return reconstruction_full

    def get_reconstructed_masked_tokens(self, reconstruction_full, real_value_full, unmasked_token_index, masked_token_index):
        """Get reconstructed masked tokens and corresponding ground-truth for subsequent loss computing.

        Args:
            reconstruction_full (torch.Tensor): reconstructed full tokens (unmasked first,
                masked last — the ordering produced by `decoding`).
            real_value_full (torch.Tensor): ground truth full tokens, [B, L*P, N, C].
            unmasked_token_index (list): unmasked token index.
            masked_token_index (list): masked token index.

        Returns:
            torch.Tensor: reconstructed masked tokens, [B, r*P*L, N].
            torch.Tensor: ground truth masked tokens, [B, r*P*L, N].
        """
        # get reconstructed masked tokens: they occupy the tail of the token axis,
        # after the len(unmasked_token_index) visible tokens
        batch_size, num_nodes, _, _ = reconstruction_full.shape
        reconstruction_masked_tokens = reconstruction_full[:, :, len(unmasked_token_index):, :]     # B, N, r*P, d
        reconstruction_masked_tokens = reconstruction_masked_tokens.view(batch_size, num_nodes, -1).transpose(1, 2)     # B, r*P*d, N
        # re-patchify the raw input with unfold, keeping only the selected feature
        # channel, so labels line up with the per-patch reconstructions
        label_full = real_value_full.permute(0, 3, 1, 2).unfold(1, self.patch_size, self.patch_size)[:, :, :, self.selected_feature, :].transpose(1, 2)     # B, N, P, L
        # select the masked patches by their original indices
        label_masked_tokens = label_full[:, :, masked_token_index, :].contiguous()     # B, N, r*P, L
        label_masked_tokens = label_masked_tokens.view(batch_size, num_nodes, -1).transpose(1, 2)     # B, r*P*L, N
        return reconstruction_masked_tokens, label_masked_tokens

    def forward(self, history_data: torch.Tensor, future_data: torch.Tensor = None, batch_seen: int = None, epoch: int = None, **kwargs) -> torch.Tensor:
        """feed forward of the TSFormer.

        TSFormer has two modes: the pre-training mode and the forecasting mode,
        which are used in the pre-training stage and the forecasting stage, respectively.

        Args:
            history_data (torch.Tensor): very long-term historical time series with shape B, L * P, N, 1.
            future_data (torch.Tensor): unused; kept for a uniform runner interface.
            batch_seen (int): unused; kept for a uniform runner interface.
            epoch (int): unused; kept for a uniform runner interface.

        Returns:
            pre-training:
                torch.Tensor: the reconstruction of the masked tokens. Shape [B, L * P * r, N]
                torch.Tensor: the ground truth of the masked tokens. Shape [B, L * P * r, N]
            forecasting:
                torch.Tensor: the output of the TSFormer encoder with shape [B, N, P, d].
        """
        # reshape to channel-first layout expected by the patch embedding
        history_data = history_data.permute(0, 2, 3, 1)     # B, N, 1, L * P
        # feed forward
        if self.mode == "pre-train":
            # encoding: only unmasked tokens pass through the encoder
            hidden_states_unmasked, unmasked_token_index, masked_token_index = self.encoding(history_data)
            # decoding: reinsert mask tokens and reconstruct every patch
            reconstruction_full = self.decoding(hidden_states_unmasked, masked_token_index)
            # for subsequent loss computing: keep only the masked positions
            reconstruction_masked_tokens, label_masked_tokens = self.get_reconstructed_masked_tokens(reconstruction_full, history_data, unmasked_token_index, masked_token_index)
            return reconstruction_masked_tokens, label_masked_tokens
        else:
            # forecasting: encode the full series without masking
            hidden_states_full, _, _ = self.encoding(history_data, mask=False)
            return hidden_states_full