import torch.nn as nn

from model.D2STGNN.diffusion_block.forecast import Forecast
from model.D2STGNN.diffusion_block.dif_model import STLocalizedConv
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp


class DifBlock(nn.Module):
    """Diffusion block of D2STGNN.

    Wraps a spatial-temporal localized convolution (the diffusion model),
    a forecast branch that predicts future hidden states, a backcast
    branch (a linear layer), and a residual-decomposition link that
    removes the learned signal from the input before it is passed on to
    the inherent model.
    """

    def __init__(self, hidden_dim, forecast_hidden_dim=256, dy_graph=None, **model_args):
        """
        Args:
            hidden_dim (int): hidden dimension.
            forecast_hidden_dim (int, optional): hidden dimension of the
                forecast branch. Defaults to 256.
            dy_graph (bool, optional): whether a dynamic graph is used.
                Defaults to None.
        """
        super().__init__()
        # Diffusion model — this variant only keeps the dynamic graph.
        self.localized_st_conv = STLocalizedConv(hidden_dim, dy_graph=dy_graph, **model_args)
        # Forecast branch: predicts future hidden states.
        self.forecast_branch = Forecast(hidden_dim, forecast_hidden_dim=forecast_hidden_dim, **model_args)
        # Backcast branch: reconstructs the (shortened) input sequence.
        self.backcast_branch = nn.Linear(hidden_dim, hidden_dim)
        # Residual decomposition: subtracts the backcast from the input.
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])

    def forward(self, history_data, gated_history_data, dynamic_graph, static_graph=None):
        """Run the diffusion model, forecast/backcast branches, and the
        residual-decomposition link.

        Args:
            history_data (torch.Tensor): history data, shape
                [batch_size, seq_len, num_nodes, hidden_dim].
            gated_history_data (torch.Tensor): gated history data, shape
                [batch_size, seq_len, num_nodes, hidden_dim].
            dynamic_graph (list): dynamic graphs.
            static_graph (list, optional): static graphs (unused by this
                variant; forwarded as-is).

        Returns:
            torch.Tensor: output of the decoupling mechanism (backcast
                branch plus residual link), to be fed to the inherent
                model. Shape: [batch_size, seq_len', num_nodes,
                hidden_dim]. Note that the st conv shortens the sequence.
            torch.Tensor: output of the forecast branch, used for the
                final prediction. Shape: [batch_size, seq_len'',
                num_nodes, forecast_hidden_dim], where
                seq_len'' = future_len / gap. To reduce error
                accumulation of the AR forecasting strategy, each hidden
                state generates the prediction of `gap` points rather
                than a single point.
        """
        # Diffusion model over the gated input (dynamic graph only).
        dif_states = self.localized_st_conv(gated_history_data, dynamic_graph, static_graph)

        # Forecast branch reuses the localized st conv to roll out
        # future hidden states.
        forecast_states = self.forecast_branch(
            gated_history_data, dif_states, self.localized_st_conv, dynamic_graph, static_graph)

        # Backcast branch: a plain FC layer over the diffusion states.
        backcast = self.backcast_branch(dif_states)

        # Residual decomposition: align the history to the shortened
        # backcast sequence, then remove the learned knowledge from it.
        trimmed_history = history_data[:, -backcast.shape[1]:, :, :]
        residual = self.residual_decompose(trimmed_history, backcast)

        return residual, forecast_states