import torch
from torch import nn

from model.STID.MLP import MultiLayerPerceptron


class STID(nn.Module):
    """
    Paper: Spatial-Temporal Identity: A Simple yet Effective Baseline for Multivariate Time Series Forecasting
    Link: https://arxiv.org/abs/2208.05233
    Official Code: https://github.com/zezhishao/STID
    """

    def __init__(self, model_args):
        super().__init__()
        # attributes
        self.num_nodes = model_args["num_nodes"]
        self.node_dim = model_args["node_dim"]
        self.input_len = model_args["input_len"]
        self.input_dim = model_args["input_dim"]
        self.embed_dim = model_args["embed_dim"]
        self.output_len = model_args["output_len"]
        self.num_layer = model_args["num_layer"]
        self.temp_dim_tid = model_args["temp_dim_tid"]
        self.temp_dim_diw = model_args["temp_dim_diw"]
        self.time_of_day_size = model_args["time_of_day_size"]
        self.day_of_week_size = model_args["day_of_week_size"]

        self.if_time_in_day = model_args["if_T_i_D"]
        self.if_day_in_week = model_args["if_D_i_W"]
        self.if_spatial = model_args["if_node"]

        # spatial embeddings
        if self.if_spatial:
            self.node_emb = nn.Parameter(torch.empty(self.num_nodes, self.node_dim))
            nn.init.xavier_uniform_(self.node_emb)
        # temporal embeddings
        if self.if_time_in_day:
            self.time_in_day_emb = nn.Parameter(
                torch.empty(self.time_of_day_size, self.temp_dim_tid)
            )
            nn.init.xavier_uniform_(self.time_in_day_emb)
        if self.if_day_in_week:
            self.day_in_week_emb = nn.Parameter(
                torch.empty(self.day_of_week_size, self.temp_dim_diw)
            )
            nn.init.xavier_uniform_(self.day_in_week_emb)

        # embedding layer: a 1x1 convolution that maps the flattened history of
        # each node (input_dim * input_len channels) to embed_dim channels
        self.time_series_emb_layer = nn.Conv2d(
            in_channels=self.input_dim * self.input_len,
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )

        # encoding: each enabled embedding contributes its own width to the
        # hidden channels (note: tid pairs with if_time_in_day and diw with
        # if_day_in_week, matching the concatenation in forward)
        self.hidden_dim = (
            self.embed_dim
            + self.node_dim * int(self.if_spatial)
            + self.temp_dim_tid * int(self.if_time_in_day)
            + self.temp_dim_diw * int(self.if_day_in_week)
        )
        self.encoder = nn.Sequential(
            *[
                MultiLayerPerceptron(self.hidden_dim, self.hidden_dim)
                for _ in range(self.num_layer)
            ]
        )

        # regression: a 1x1 convolution that maps hidden_dim channels to
        # output_len predictions per node
        self.regression_layer = nn.Conv2d(
            in_channels=self.hidden_dim,
            out_channels=self.output_len,
            kernel_size=(1, 1),
            bias=True,
        )

    def forward(self, history_data: torch.Tensor) -> torch.Tensor:
        """Feed forward of STID.

        Args:
            history_data (torch.Tensor): history data with shape [B, L, N, C]

        Returns:
            torch.Tensor: prediction with shape [B, output_len, N, 1]
        """

        # prepare data: keep only the first input_dim feature channels
        input_data = history_data[..., : self.input_dim]

        if self.if_time_in_day:
            t_i_d_data = history_data[..., 1]
            # In the datasets used in STID, the time_of_day feature is normalized
            # to [0, 1]; multiplying by time_of_day_size (e.g. 288 for 5-minute
            # slots) recovers the slot index of the last step. Other datasets may
            # need a different mapping.
            time_in_day_emb = self.time_in_day_emb[
                (t_i_d_data[:, -1, :] * self.time_of_day_size).long()
            ]
        else:
            time_in_day_emb = None
        if self.if_day_in_week:
            d_i_w_data = history_data[..., 2]
            day_in_week_emb = self.day_in_week_emb[
                (d_i_w_data[:, -1, :] * self.day_of_week_size).long()
            ]
        else:
            day_in_week_emb = None

        # time series embedding
        batch_size, _, num_nodes, _ = input_data.shape
        # [B, L, N, C] -> [B, N, L*C] -> [B, L*C, N, 1] for the 1x1 convolution
        input_data = input_data.transpose(1, 2).contiguous()
        input_data = (
            input_data.view(batch_size, num_nodes, -1).transpose(1, 2).unsqueeze(-1)
        )
        time_series_emb = self.time_series_emb_layer(input_data)

        node_emb = []
        if self.if_spatial:
            # expand node embeddings to [B, node_dim, N, 1]
            node_emb.append(
                self.node_emb.unsqueeze(0)
                .expand(batch_size, -1, -1)
                .transpose(1, 2)
                .unsqueeze(-1)
            )
        # temporal embeddings
        tem_emb = []
        if time_in_day_emb is not None:
            tem_emb.append(time_in_day_emb.transpose(1, 2).unsqueeze(-1))
        if day_in_week_emb is not None:
            tem_emb.append(day_in_week_emb.transpose(1, 2).unsqueeze(-1))

        # concatenate all embeddings along the channel dimension
        hidden = torch.cat([time_series_emb] + node_emb + tem_emb, dim=1)

        # encoding
        hidden = self.encoder(hidden)

        # regression
        prediction = self.regression_layer(hidden)

        return prediction
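

# A minimal smoke test, assuming the repo's MultiLayerPerceptron is importable
# and that input channel 0 carries the signal, channel 1 the normalized
# time-of-day, and channel 2 the normalized day-of-week. The hyperparameter
# values below are illustrative, not the paper's settings.
if __name__ == "__main__":
    model_args = {
        "num_nodes": 7,
        "node_dim": 32,
        "input_len": 12,
        "input_dim": 1,
        "embed_dim": 32,
        "output_len": 12,
        "num_layer": 3,
        "temp_dim_tid": 32,
        "temp_dim_diw": 32,
        "time_of_day_size": 288,
        "day_of_week_size": 7,
        "if_T_i_D": True,
        "if_D_i_W": True,
        "if_node": True,
    }
    model = STID(model_args)
    # history: [B, L, N, C] with C >= 3 (signal, time_of_day, day_of_week);
    # torch.rand keeps the temporal features in [0, 1), so the index lookups
    # in forward stay in range
    history = torch.rand(4, model_args["input_len"], model_args["num_nodes"], 3)
    prediction = model(history)
    print(prediction.shape)  # expected: torch.Size([4, 12, 7, 1])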