import torch
import torch.nn as nn


class EstimationGate(nn.Module):
    """Scale a history tensor by a learned per-node, per-time-step gate.

    The gate is produced by a two-layer perceptron (Linear -> ReLU -> Linear
    -> sigmoid) applied to the concatenation of two time features and two
    node embeddings, so every gate value lies in (0, 1).
    """

    def __init__(self, node_emb_dim: int, time_emb_dim: int, hidden_dim: int) -> None:
        """Build the two linear layers of the gate.

        Args:
            node_emb_dim: Size of the last axis of each node embedding.
            time_emb_dim: Size of the last axis of each time feature
                (assumed equal for both time features; only their sum,
                ``2 * time_emb_dim``, is enforced by the first layer).
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        # Input width = two time features + two node embeddings, concatenated.
        self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
        self.activation = nn.ReLU()
        # A single output channel: one scalar gate per (batch, step, node).
        self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1)

    def forward(
        self,
        node_embedding_u: torch.Tensor,
        node_embedding_d: torch.Tensor,
        time_in_day_feat: torch.Tensor,
        day_in_week_feat: torch.Tensor,
        history_data: torch.Tensor,
    ) -> torch.Tensor:
        """Generate gate values in (0, 1) and apply them to ``history_data``.

        The gate is computed from the node and time-step embeddings and is
        meant to roughly estimate the proportion of the two hidden time
        series.

        Args:
            node_embedding_u: 2-D node embedding, expected to be
                ``(num_nodes, node_emb_dim)``; broadcast over batch and time.
            node_embedding_d: Second node embedding, same layout as
                ``node_embedding_u``.
            time_in_day_feat: 4-D tensor ``(batch, seq_length, num_nodes, dim)``.
            day_in_week_feat: Tensor concatenated with ``time_in_day_feat``
                along the last axis, so all other axes must match.
            history_data: Tensor whose axis 1 is the time axis; presumably
                ``(batch, history_len, num_nodes, channels)``. It may cover
                fewer steps than the time features.

        Returns:
            ``history_data`` multiplied element-wise by the gate values of
            the last ``history_len`` time steps (the gate's trailing axis of
            size 1 broadcasts over the last axis of ``history_data``).
        """
        batch_size, seq_length, _, _ = time_in_day_feat.shape

        # Add leading batch/time axes to the node embeddings and expand them
        # (without copying) so they line up with the time features.
        expanded_u = node_embedding_u.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1)
        expanded_d = node_embedding_d.unsqueeze(0).unsqueeze(0).expand(batch_size, seq_length, -1, -1)
        estimation_gate_feat = torch.cat(
            [time_in_day_feat, day_in_week_feat, expanded_u, expanded_d],
            dim=-1,
        )

        hidden = self.fully_connected_layer_1(estimation_gate_feat)
        hidden = self.activation(hidden)
        estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))

        # Keep only the gate values for the last `history_len` steps. A
        # non-negative start index is used instead of `-history_len:` because
        # `-0:` would select the whole sequence when the history is empty.
        history_len = history_data.shape[1]
        start = max(seq_length - history_len, 0)
        estimation_gate = estimation_gate[:, start:, :, :]

        return history_data * estimation_gate