import math

import torch
import torch.nn as nn

from model.D2STGNN.decouple.residual_decomp import ResidualDecomp
from model.D2STGNN.inherent_block.inh_model import RNNLayer, TransformerLayer
from model.D2STGNN.inherent_block.forecast import Forecast

class PositionalEncoding(nn.Module):
    """Standard sinusoidal positional encoding, added along the sequence (first) dimension."""

    def __init__(self, d_model, dropout=None, max_len: int = 5000):
        super().__init__()
        # nn.Dropout cannot take p=None, so fall back to an identity when no dropout is given
        self.dropout = nn.Dropout(p=dropout) if dropout is not None else nn.Identity()
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        # registered as a buffer: moves with the module's device/dtype but is not trained
        self.register_buffer('pe', pe)

    def forward(self, X):
        # X: [seq_len, batch, d_model]; add the first seq_len positional encodings (broadcast over batch)
        X = X + self.pe[:X.size(0)]
        X = self.dropout(X)
        return X

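# A quick, self-contained sanity check of PositionalEncoding (hypothetical shapes,
# not part of the D2STGNN pipeline): the encoding is indexed by the first (sequence)
# dimension and broadcast over the batch dimension.
#
#   >>> pos_enc = PositionalEncoding(d_model=32, dropout=0.1)
#   >>> x = torch.zeros(12, 4, 32)      # [seq_len, batch, d_model]
#   >>> pos_enc(x).shape
#   torch.Size([12, 4, 32])
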
class InhBlock(nn.Module):
    def __init__(self, hidden_dim, num_heads=4, bias=True, forecast_hidden_dim=256, **model_args):
        """Inherent block.

        Args:
            hidden_dim (int): hidden dimension.
            num_heads (int, optional): number of heads in the multi-head self-attention (MSA) layer. Defaults to 4.
            bias (bool, optional): whether to use bias in the MSA layer. Defaults to True.
            forecast_hidden_dim (int, optional): hidden dimension of the forecast branch. Defaults to 256.
        """
        super().__init__()
        self.num_feat = hidden_dim
        self.hidden_dim = hidden_dim

        # inherent model: RNN -> positional encoding -> multi-head self-attention
        self.pos_encoder = PositionalEncoding(hidden_dim, model_args['dropout'])
        self.rnn_layer = RNNLayer(hidden_dim, model_args['dropout'])
        self.transformer_layer = TransformerLayer(hidden_dim, num_heads, model_args['dropout'], bias)

        # forecast branch: produces the hidden states used for the final prediction
        self.forecast_block = Forecast(hidden_dim, forecast_hidden_dim, **model_args)
        # backcast branch: reconstructs the inherent signal captured by this block
        self.backcast_fc = nn.Linear(hidden_dim, hidden_dim)
        # residual decomposition: removes the backcast signal from the block input
        self.residual_decompose = ResidualDecomp([-1, -1, -1, hidden_dim])
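    # Note: construction requires `model_args` to contain at least the 'dropout' key used
    # above (a missing key raises KeyError); the remaining keys are forwarded to `Forecast`.
    # Hypothetical example:
    #
    #   block = InhBlock(hidden_dim=64, **{"dropout": 0.1, ...})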
    def forward(self, hidden_inherent_signal):
        """Inherent block: inherent model, forecast branch, backcast branch, and the residual decomposition link.

        Args:
            hidden_inherent_signal (torch.Tensor): hidden inherent signal with shape [batch_size, seq_len, num_nodes, num_feat].

        Returns:
            torch.Tensor: output of the decoupling mechanism (backcast branch and residual link), fed to the next decouple layer.
                Shape: [batch_size, seq_len, num_nodes, hidden_dim].
            torch.Tensor: output of the forecast branch, used to make the final prediction.
                Shape: [batch_size, seq_len'', num_nodes, forecast_hidden_dim], where seq_len'' = future_len / gap.
                To reduce error accumulation in the autoregressive forecasting strategy, each hidden state generates
                predictions for `gap` future points instead of a single point.
        """
        [batch_size, seq_len, num_nodes, num_feat] = hidden_inherent_signal.shape

        # inherent model
        # RNN along the time axis; output is [seq_len, batch_size * num_nodes, num_feat]
        # (see the inverse reshape in the backcast branch below)
        hidden_states_rnn = self.rnn_layer(hidden_inherent_signal)
        # positional encoding over the sequence dimension
        hidden_states_rnn = self.pos_encoder(hidden_states_rnn)
        # multi-head self-attention over the RNN hidden states
        hidden_states_inh = self.transformer_layer(hidden_states_rnn, hidden_states_rnn, hidden_states_rnn)

        # forecast branch
        forecast_hidden = self.forecast_block(hidden_inherent_signal, hidden_states_rnn, hidden_states_inh,
                                              self.transformer_layer, self.rnn_layer, self.pos_encoder)

        # backcast branch: restore the [batch_size, seq_len, num_nodes, num_feat] layout before projecting
        hidden_states_inh = hidden_states_inh.reshape(seq_len, batch_size, num_nodes, num_feat)
        hidden_states_inh = hidden_states_inh.transpose(0, 1)
        backcast_seq = self.backcast_fc(hidden_states_inh)

        # residual decomposition: remove the backcast signal from the block input
        backcast_seq_res = self.residual_decompose(hidden_inherent_signal, backcast_seq)

        return backcast_seq_res, forecast_hidden
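
# Forward-pass sketch (commented out: it needs the sibling D2STGNN modules and the full
# `model_args` from the training config; the sizes below are hypothetical):
#
#   block = InhBlock(hidden_dim=64, num_heads=4, forecast_hidden_dim=256, **model_args)
#   x = torch.randn(32, 12, 207, 64)               # [batch_size, seq_len, num_nodes, hidden_dim]
#   backcast_residual, forecast_hidden = block(x)
#   # backcast_residual: [32, 12, 207, 64], fed to the next decouple layer
#   # forecast_hidden:   [32, future_len / gap, 207, 256], used for the final prediction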