TrafficWheel/model/D2STGNN/diffusion_block/dif_block.py

57 lines
3.1 KiB
Python

import torch.nn as nn
from model.D2STGNN.diffusion_block.forecast import Forecast
from model.D2STGNN.diffusion_block.dif_model import STLocalizedConv
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
class DifBlock(nn.Module):
    """Diffusion block of D2STGNN.

    Combines a localized spatial-temporal convolution (the diffusion model),
    a forecast branch, a linear backcast branch, and a residual decomposition
    link.
    """

    def __init__(self, hidden_dim, forecast_hidden_dim=256, dy_graph=None, **model_args):
        """Build the sub-modules of the diffusion block.

        Args:
            hidden_dim (int): hidden dimension.
            forecast_hidden_dim (int, optional): hidden dimension of the
                forecast branch. Defaults to 256.
            dy_graph (bool, optional): whether to use the dynamic graph;
                forwarded to ``STLocalizedConv``. Defaults to None.
            **model_args: extra keyword arguments forwarded unchanged to
                ``STLocalizedConv`` and ``Forecast``.
        """
        super().__init__()

        # Diffusion model (original note: only the dynamic graph is kept).
        self.localized_st_conv = STLocalizedConv(
            hidden_dim,
            dy_graph=dy_graph,
            **model_args,
        )

        # Forecast branch.
        self.forecast_branch = Forecast(
            hidden_dim,
            forecast_hidden_dim=forecast_hidden_dim,
            **model_args,
        )

        # Backcast branch: a single fully connected layer.
        self.backcast_branch = nn.Linear(hidden_dim, hidden_dim)

        # Residual decomposition link.
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])

    def forward(self, history_data, gated_history_data, dynamic_graph, static_graph=None):
        """Run the diffusion model, both branches, and the residual link.

        Args:
            history_data (torch.Tensor): history data with shape
                [batch_size, seq_len, num_nodes, hidden_dim].
            gated_history_data (torch.Tensor): gated history data with shape
                [batch_size, seq_len, num_nodes, hidden_dim].
            dynamic_graph (list): dynamic graphs.
            static_graph (list, optional): static graphs. The original note
                marks this as unused; here it is simply passed through to the
                diffusion model and the forecast branch.

        Returns:
            torch.Tensor: output of the decoupling mechanism (backcast branch
                plus residual link), to be fed to the inherent model.
                Shape: [batch_size, seq_len', num_nodes, hidden_dim]; the
                sequence is shorter after the st conv.
            torch.Tensor: output of the forecast branch, used for the final
                prediction.
                Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim],
                with seq_len'' = future_len / gap. Each hidden state predicts
                ``gap`` points rather than a single one, to reduce error
                accumulation in the AR forecasting strategy.
        """
        # Diffusion model (original note: only the dynamic graph is used).
        diffusion_states = self.localized_st_conv(
            gated_history_data, dynamic_graph, static_graph
        )

        # Forecast branch: reuses the localized st conv to predict future
        # hidden states.
        forecast_out = self.forecast_branch(
            gated_history_data,
            diffusion_states,
            self.localized_st_conv,
            dynamic_graph,
            static_graph,
        )

        # Backcast branch: fully connected layer over the diffusion states.
        backcast = self.backcast_branch(diffusion_states)

        # Residual decomposition: keep only the trailing time steps of the
        # history so it lines up with the (shorter) backcast sequence, then
        # remove the learned knowledge from the input data.
        kept_steps = backcast.size(1)
        aligned_history = history_data[:, -kept_steps:, :, :]
        residual = self.residual_decompose(aligned_history, backcast)

        return residual, forecast_out