import torch
from torch import nn

from model.STID.MLP import MultiLayerPerceptron


class STID(nn.Module):
    """
    Paper: Spatial-Temporal Identity: A Simple yet Effective Baseline for Multivariate Time Series Forecasting
    Link: https://arxiv.org/abs/2208.05233
    Official Code: https://github.com/zezhishao/STID
    """

    def __init__(self, model_args):
        """Build the STID model from a configuration mapping.

        Args:
            model_args: mapping providing ``num_nodes``, ``node_dim``,
                ``input_len``, ``input_dim``, ``embed_dim``, ``output_len``,
                ``num_layer``, ``temp_dim_tid``, ``temp_dim_diw``,
                ``time_of_day_size``, ``day_of_week_size`` and the boolean
                switches ``if_T_i_D``, ``if_D_i_W`` and ``if_node``.
        """
        super().__init__()
        # attributes
        self.num_nodes = model_args["num_nodes"]
        self.node_dim = model_args["node_dim"]
        self.input_len = model_args["input_len"]
        self.input_dim = model_args["input_dim"]
        self.embed_dim = model_args["embed_dim"]
        self.output_len = model_args["output_len"]
        self.num_layer = model_args["num_layer"]
        self.temp_dim_tid = model_args["temp_dim_tid"]
        self.temp_dim_diw = model_args["temp_dim_diw"]
        self.time_of_day_size = model_args["time_of_day_size"]
        self.day_of_week_size = model_args["day_of_week_size"]

        self.if_time_in_day = model_args["if_T_i_D"]
        self.if_day_in_week = model_args["if_D_i_W"]
        self.if_spatial = model_args["if_node"]

        # spatial embeddings: one learnable vector per node
        if self.if_spatial:
            self.node_emb = nn.Parameter(torch.empty(self.num_nodes, self.node_dim))
            nn.init.xavier_uniform_(self.node_emb)
        # temporal embeddings: one learnable vector per time-of-day slot / weekday
        if self.if_time_in_day:
            self.time_in_day_emb = nn.Parameter(
                torch.empty(self.time_of_day_size, self.temp_dim_tid)
            )
            nn.init.xavier_uniform_(self.time_in_day_emb)
        if self.if_day_in_week:
            self.day_in_week_emb = nn.Parameter(
                torch.empty(self.day_of_week_size, self.temp_dim_diw)
            )
            nn.init.xavier_uniform_(self.day_in_week_emb)

        # embedding layer: a 1x1 conv over the flattened (input_len * input_dim)
        # window of each node
        self.time_series_emb_layer = nn.Conv2d(
            in_channels=self.input_dim * self.input_len,
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )

        # encoding
        # The hidden width must equal the channel count produced by the
        # torch.cat in forward(), so each embedding width is gated by its
        # OWN flag (time-of-day width by if_time_in_day, day-of-week width by
        # if_day_in_week).
        self.hidden_dim = (
            self.embed_dim
            + self.node_dim * int(self.if_spatial)
            + self.temp_dim_tid * int(self.if_time_in_day)
            + self.temp_dim_diw * int(self.if_day_in_week)
        )
        self.encoder = nn.Sequential(
            *[
                MultiLayerPerceptron(self.hidden_dim, self.hidden_dim)
                for _ in range(self.num_layer)
            ]
        )

        # regression: map hidden channels to one channel per output step
        self.regression_layer = nn.Conv2d(
            in_channels=self.hidden_dim,
            out_channels=self.output_len,
            kernel_size=(1, 1),
            bias=True,
        )

    def forward(self, history_data: torch.Tensor) -> torch.Tensor:
        """Feed forward of STID.

        Args:
            history_data (torch.Tensor): history data with shape [B, L, N, C].
                Channels ``[0, input_dim)`` feed the time-series embedding.
                When enabled, channel 1 is read as time-of-day and channel 2
                as day-of-week; both must lie in [0, 1) so that
                ``value * size`` truncates to a valid embedding row.

        Returns:
            torch.Tensor: prediction with shape [B, output_len, N, 1]
        """
        # prepare data
        input_data = history_data[..., : self.input_dim]

        if self.if_time_in_day:
            # Only the last input step's time-of-day is used. In the datasets
            # used by STID this feature is normalized to [0, 1); scaling by
            # time_of_day_size and truncating gives the slot index. If you use
            # other datasets, you may need to change this line.
            # `.long()` keeps the index on the data's device (the former
            # `.type(torch.LongTensor)` forced a copy to the CPU).
            t_i_d_data = history_data[..., 1]
            tid_index = (t_i_d_data[:, -1, :] * self.time_of_day_size).long()
            time_in_day_emb = self.time_in_day_emb[tid_index]  # [B, N, temp_dim_tid]
        else:
            time_in_day_emb = None

        if self.if_day_in_week:
            d_i_w_data = history_data[..., 2]
            diw_index = (d_i_w_data[:, -1, :] * self.day_of_week_size).long()
            day_in_week_emb = self.day_in_week_emb[diw_index]  # [B, N, temp_dim_diw]
        else:
            day_in_week_emb = None

        # time series embedding:
        # [B, L, N, C] -> [B, N, L*C] -> [B, L*C, N, 1] -> [B, embed_dim, N, 1]
        batch_size, _, num_nodes, _ = input_data.shape
        input_data = input_data.transpose(1, 2).contiguous()
        input_data = (
            input_data.view(batch_size, num_nodes, -1).transpose(1, 2).unsqueeze(-1)
        )
        time_series_emb = self.time_series_emb_layer(input_data)

        node_emb = []
        if self.if_spatial:
            # expand node embeddings: [N, node_dim] -> [B, node_dim, N, 1]
            node_emb.append(
                self.node_emb.unsqueeze(0)
                .expand(batch_size, -1, -1)
                .transpose(1, 2)
                .unsqueeze(-1)
            )

        # temporal embeddings: [B, N, D] -> [B, D, N, 1]
        tem_emb = []
        if time_in_day_emb is not None:
            tem_emb.append(time_in_day_emb.transpose(1, 2).unsqueeze(-1))
        if day_in_week_emb is not None:
            tem_emb.append(day_in_week_emb.transpose(1, 2).unsqueeze(-1))

        # concatenate all embeddings along the channel axis: [B, hidden_dim, N, 1]
        hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)

        # encoding
        hidden = self.encoder(hidden)

        # regression: [B, output_len, N, 1]
        prediction = self.regression_layer(hidden)

        return prediction