TrafficWheel/model/D2STGNN/inherent_block/inh_block.py


import math
import torch
import torch.nn as nn
from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer
from model.D2STGNN.inherent_block.forecast import Forecast


class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding applied over the sequence dimension."""

    def __init__(self, d_model, dropout=0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Pre-compute the sinusoidal table once; shape [max_len, 1, d_model] so it
        # broadcasts over the batch dimension of a [seq_len, batch, d_model] input.
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, X):
        X = X + self.pe[:X.size(0)]
        X = self.dropout(X)
        return X
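

# Minimal usage sketch for PositionalEncoding (kept as comments so the module has no
# import-time side effects). Assumption: the input is laid out [seq_len, batch, d_model],
# which is what the `pe[:X.size(0)]` broadcast expects; inside InhBlock the "batch" axis
# is actually batch_size * num_nodes.
#
#     pos_enc = PositionalEncoding(d_model=32, dropout=0.1)
#     x = torch.zeros(12, 8, 32)   # [seq_len, batch, d_model]
#     y = pos_enc(x)               # same shape, with positional encoding and dropout applied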


class InhBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads=4, bias=True, forecast_hidden_dim=256, **model_args):
        """Inherent block.

        Args:
            hidden_dim (int): hidden dimension.
            num_heads (int, optional): number of heads of the multi-head self-attention. Defaults to 4.
            bias (bool, optional): whether to use bias. Defaults to True.
            forecast_hidden_dim (int, optional): hidden dimension of the forecast branch. Defaults to 256.
        """
        super().__init__()
        self.num_feat = hidden_dim
        self.hidden_dim = hidden_dim

        # inherent model: RNN -> positional encoding -> multi-head self-attention
        self.pos_encoder = PositionalEncoding(hidden_dim, model_args['dropout'])
        self.rnn_layer = RNNLayer(hidden_dim, model_args['dropout'])
        self.transformer_layer = TransformerLayer(hidden_dim, num_heads, model_args['dropout'], bias)
        # forecast branch
        self.forecast_block = Forecast(hidden_dim, forecast_hidden_dim, **model_args)
        # backcast branch
        self.backcast_fc = nn.Linear(hidden_dim, hidden_dim)
        # residual decomposition
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])

    def forward(self, hidden_inherent_signal):
        """Inherent block: inherent model, forecast branch, backcast branch, and the residual decomposition link.

        Args:
            hidden_inherent_signal (torch.Tensor): hidden inherent signal with shape
                [batch_size, seq_len, num_nodes, num_feat].

        Returns:
            torch.Tensor: output of the decoupling mechanism (backcast branch plus residual link),
                fed to the next decouple layer. Shape: [batch_size, seq_len, num_nodes, hidden_dim].
            torch.Tensor: output of the forecast branch, used to make the final prediction.
                Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim], where
                seq_len'' = future_len / gap. To reduce error accumulation in the autoregressive
                forecasting strategy, each hidden state generates the prediction of `gap` points
                instead of a single point.
        """
        batch_size, seq_len, num_nodes, num_feat = hidden_inherent_signal.shape

        # inherent model
        # rnn over the time dimension; output is sequence-first, flattened over the batch
        # and node dimensions (see the reshape in the backcast branch below)
        hidden_states_rnn = self.rnn_layer(hidden_inherent_signal)
        # positional encoding
        hidden_states_rnn = self.pos_encoder(hidden_states_rnn)
        # multi-head self-attention
        hidden_states_inh = self.transformer_layer(hidden_states_rnn, hidden_states_rnn, hidden_states_rnn)

        # forecast branch
        forecast_hidden = self.forecast_block(hidden_inherent_signal, hidden_states_rnn, hidden_states_inh, self.transformer_layer, self.rnn_layer, self.pos_encoder)

        # backcast branch: restore the [batch_size, seq_len, num_nodes, num_feat] layout
        hidden_states_inh = hidden_states_inh.reshape(seq_len, batch_size, num_nodes, num_feat)
        hidden_states_inh = hidden_states_inh.transpose(0, 1)
        backcast_seq = self.backcast_fc(hidden_states_inh)

        # residual decomposition
        backcast_seq_res = self.residual_decompose(hidden_inherent_signal, backcast_seq)
        return backcast_seq_res, forecast_hidden
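

if __name__ == "__main__":
    # Minimal shape-check sketch, not part of the model. Assumptions: the sibling
    # D2STGNN modules imported above are available, and `Forecast` consumes only the
    # model_args keys listed below. Any extra keys (e.g. the forecast gap / horizon
    # hinted at in the forward() docstring) are repo-specific, so the values here are
    # placeholders to adjust, not the real configuration.
    demo_model_args = {
        "dropout": 0.1,
        # "gap": 3, "output_seq_len": 12,  # hypothetical keys Forecast may expect
    }
    block = InhBlock(hidden_dim=32, num_heads=4, forecast_hidden_dim=256, **demo_model_args)
    x = torch.randn(2, 12, 207, 32)  # [batch_size, seq_len, num_nodes, hidden_dim]
    backcast_res, forecast_hidden = block(x)
    print(backcast_res.shape, forecast_hidden.shape)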