# TrafficWheel/model/D2STGNN/decouple/estimation_gate.py

import torch
import torch.nn as nn
class EstimationGate(nn.Module):
    """The estimation gate module.

    A small MLP (Linear -> ReLU -> Linear -> sigmoid) maps the concatenation of
    time-of-day features, day-of-week features and two node embeddings to one
    scalar gate in (0, 1) per position. That gate is then used to scale
    ``history_data``.
    """

    def __init__(self, node_emb_dim, time_emb_dim, hidden_dim):
        """Build the gate MLP.

        Args:
            node_emb_dim: node-embedding width. The first layer expects
                ``2 * node_emb_dim`` node features in total (presumably
                ``node_emb_dim`` from each of the two node embeddings).
            time_emb_dim: time-feature width. The first layer expects
                ``2 * time_emb_dim`` time features in total (presumably
                ``time_emb_dim`` each for time-of-day and day-of-week).
            hidden_dim: width of the hidden layer.
        """
        super().__init__()
        # Input layout along the last axis: [time_in_day | day_in_week | node_u | node_d].
        self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
        self.activation = nn.ReLU()
        # One output unit: a single gate value per position.
        self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1)

    def forward(self, node_embedding_u, node_embedding_d, time_in_day_feat, day_in_week_feat, history_data):
        """Scale ``history_data`` by a gate in (0, 1) built from node and time embeddings.

        The gate roughly estimates the proportion of the two hidden time series.

        Args:
            node_embedding_u: node embedding, broadcast over batch and time.
                Must be 2-D, presumably ``(num_nodes, dim)``.
            node_embedding_d: second node embedding, same layout as ``node_embedding_u``.
            time_in_day_feat: 4-D tensor ``(batch, seq_length, num_nodes?, dim)``.
                The shape unpacking below requires 4-D; the axis-2 meaning is
                assumed to be nodes (TODO confirm against callers).
            day_in_week_feat: same leading shape as ``time_in_day_feat``.
            history_data: 4-D tensor whose axis 1 is time. Only the last
                ``history_data.shape[1]`` time steps of the features are used.

        Returns:
            ``history_data * gate``. The gate's last axis has size 1, so it
            broadcasts over ``history_data``'s last axis.
        """
        batch_size, _, _, _ = time_in_day_feat.shape
        history_length = history_data.shape[1]

        # The MLP acts independently on each (batch, step, node) position. Keeping
        # only the trailing `history_length` steps *before* the MLP yields the same
        # gate values as computing the full sequence and slicing afterwards, without
        # the wasted work. (As before, history_length == 0 slices to the full length.)
        time_in_day_feat = time_in_day_feat[:, -history_length:, :, :]
        day_in_week_feat = day_in_week_feat[:, -history_length:, :, :]
        window = time_in_day_feat.shape[1]

        # Broadcast each 2-D node embedding to (batch, window, N, dim) so it can be concatenated.
        node_feats = [
            emb.unsqueeze(0).unsqueeze(0).expand(batch_size, window, -1, -1)
            for emb in (node_embedding_u, node_embedding_d)
        ]
        estimation_gate_feat = torch.cat([time_in_day_feat, day_in_week_feat, *node_feats], dim=-1)

        hidden = self.activation(self.fully_connected_layer_1(estimation_gate_feat))
        # Squash to (0, 1).
        estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))
        return history_data * estimation_gate