TrafficWheel/model/STID/STID.py

110 lines
4.1 KiB
Python
Executable File

import torch
from torch import nn
from model.STID.MLP import MultiLayerPerceptron
class STID(nn.Module):
    """Identity-embedding MLP forecaster over a fixed set of nodes.

    Pipeline (per ``forward``):

    1. Each node's input window (``input_len`` steps x ``input_dim`` channels)
       is flattened and projected to ``embed_dim`` by a 1x1 convolution.
    2. Optional learnable embeddings are concatenated along the channel
       axis, in this order:
       - a per-node embedding (``if_node``);
       - a time-of-day embedding (``if_T_i_D``);
       - a day-of-week embedding (``if_D_i_W``).
       The two temporal embeddings are looked up from the last input step.
    3. The result goes through ``num_layer`` ``MultiLayerPerceptron`` blocks.
    4. A final 1x1 convolution regresses it to ``output_len`` future steps.

    Args:
        model_args: mapping that must provide the keys ``num_nodes``,
            ``node_dim``, ``input_len``, ``input_dim``, ``embed_dim``,
            ``output_len``, ``num_layer``, ``temp_dim_tid``, ``temp_dim_diw``,
            ``time_of_day_size``, ``day_of_week_size``, ``if_T_i_D``,
            ``if_D_i_W`` and ``if_node``.
    """

    def __init__(self, model_args):
        super().__init__()
        self.num_nodes = model_args["num_nodes"]
        self.node_dim = model_args["node_dim"]
        self.input_len = model_args["input_len"]
        self.input_dim = model_args["input_dim"]
        self.embed_dim = model_args["embed_dim"]
        self.output_len = model_args["output_len"]
        self.num_layer = model_args["num_layer"]
        self.temp_dim_tid = model_args["temp_dim_tid"]
        self.temp_dim_diw = model_args["temp_dim_diw"]
        self.time_of_day_size = model_args["time_of_day_size"]
        self.day_of_week_size = model_args["day_of_week_size"]
        self.if_time_in_day = model_args["if_T_i_D"]
        self.if_day_in_week = model_args["if_D_i_W"]
        self.if_spatial = model_args["if_node"]

        # NOTE: parameters are created in a fixed order. Reordering them would
        # change random initialisation under a fixed seed, and renaming them
        # would break existing state_dict checkpoints.
        if self.if_spatial:
            self.node_emb = nn.Parameter(torch.empty(self.num_nodes, self.node_dim))
            nn.init.xavier_uniform_(self.node_emb)
        if self.if_time_in_day:
            self.time_in_day_emb = nn.Parameter(
                torch.empty(self.time_of_day_size, self.temp_dim_tid)
            )
            nn.init.xavier_uniform_(self.time_in_day_emb)
        if self.if_day_in_week:
            self.day_in_week_emb = nn.Parameter(
                torch.empty(self.day_of_week_size, self.temp_dim_diw)
            )
            nn.init.xavier_uniform_(self.day_in_week_emb)

        # 1x1 conv acting as a per-node linear layer on the flattened window.
        self.time_series_emb_layer = nn.Conv2d(
            in_channels=self.input_dim * self.input_len,
            out_channels=self.embed_dim,
            kernel_size=(1, 1),
            bias=True,
        )

        # Width of the concatenated representation; disabled embeddings add 0.
        self.hidden_dim = (
            self.embed_dim
            + self.node_dim * int(self.if_spatial)
            + self.temp_dim_tid * int(self.if_time_in_day)
            + self.temp_dim_diw * int(self.if_day_in_week)
        )
        self.encoder = nn.Sequential(
            *[
                MultiLayerPerceptron(self.hidden_dim, self.hidden_dim)
                for _ in range(self.num_layer)
            ]
        )
        self.regression_layer = nn.Conv2d(
            in_channels=self.hidden_dim,
            out_channels=self.output_len,
            kernel_size=(1, 1),
            bias=True,
        )

    @staticmethod
    def _lookup_temporal_emb(
        history_data: torch.Tensor,
        channel: int,
        table: torch.Tensor,
        num_slots: int,
    ) -> torch.Tensor:
        """Look up a temporal embedding from the last input step.

        Steps:
        1. Read ``history_data[:, -1, :, channel]``, one value per
           batch/node.
        2. Multiply by ``num_slots`` and truncate to an integer slot.
        3. Clamp the slot into ``[0, num_slots - 1]`` and index ``table``
           (shape ``[num_slots, D]``).

        Args:
            history_data: ``[B, L, N, C]`` input tensor.
            channel: feature channel holding the temporal value.
            table: embedding table of shape ``[num_slots, D]``.
            num_slots: number of rows in ``table``.

        Returns:
            Tensor of shape ``[B, D, 1, N]``, ready to concatenate on dim 1.
        """
        # NOTE(review): scale-and-truncate presumes the channel stores a
        # fraction in [0, 1). Confirm this against the data pipeline.
        frac = history_data[:, -1, :, channel]  # [B, N]
        idx = torch.clamp((frac * num_slots).long(), 0, num_slots - 1)
        return table[idx].transpose(1, 2).unsqueeze(2)  # [B, D, 1, N]

    def forward(self, history_data: torch.Tensor) -> torch.Tensor:
        """Predict ``output_len`` future steps for every node.

        Args:
            history_data: tensor of shape ``[B, L, N, C]``. Channel usage:
                - the first ``input_dim`` channels feed the series embedding;
                - channel 1 is read as time of day when ``if_T_i_D`` is set;
                - channel 2 is read as day of week when ``if_D_i_W`` is set.

        Returns:
            Tensor of shape ``[B, output_len, N, 1]``.
        """
        # A slice selects the same channels as an index list, without the
        # copy that advanced indexing always makes.
        input_data = history_data[..., : self.input_dim]
        B, L, N, C = input_data.shape

        # Flatten each node's window into C*L channels laid out as [B, C*L, 1, N]
        # so the 1x1 conv acts independently on every node.
        x = input_data.permute(0, 3, 1, 2).reshape(B, L * C, 1, N)
        parts = [self.time_series_emb_layer(x)]  # [B, E, 1, N]

        if self.if_spatial:
            parts.append(
                self.node_emb.unsqueeze(0)
                .expand(B, -1, -1)
                .transpose(1, 2)
                .unsqueeze(2)  # [B, Dn, 1, N]
            )
        if self.if_time_in_day:
            parts.append(
                self._lookup_temporal_emb(
                    history_data, 1, self.time_in_day_emb, self.time_of_day_size
                )  # [B, Dt, 1, N]
            )
        if self.if_day_in_week:
            parts.append(
                self._lookup_temporal_emb(
                    history_data, 2, self.day_in_week_emb, self.day_of_week_size
                )  # [B, Dw, 1, N]
            )

        hidden = torch.cat(parts, dim=1)  # [B, hidden_dim, 1, N]
        hidden = self.encoder(hidden)
        prediction = self.regression_layer(hidden)  # [B, output_len, 1, N]
        return prediction.permute(0, 1, 3, 2)  # [B, output_len, N, 1]