# Python source (25 lines, 1.3 KiB).
import torch
import torch.nn as nn


class EstimationGate(nn.Module):
    """The estimation gate module.

    A small MLP (Linear -> ReLU -> Linear -> sigmoid) maps, for every
    (batch, time step, node) position, the concatenation
    ``[time_in_day | day_in_week | node_emb_u | node_emb_d]`` to a single
    gate value in (0, 1). ``history_data`` is then scaled by that gate.

    Args:
        node_emb_dim: feature size of each of the two node embeddings.
        time_emb_dim: time-feature size; the first layer expects the two time
            features to contribute ``2 * time_emb_dim`` channels in total
            (presumably ``time_emb_dim`` each -- confirm with callers).
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, node_emb_dim, time_emb_dim, hidden_dim):
        super().__init__()
        # Input width = two node embeddings + two time features.
        self.fully_connected_layer_1 = nn.Linear(2 * node_emb_dim + time_emb_dim * 2, hidden_dim)
        self.activation = nn.ReLU()
        # Single output channel: one scalar gate per position.
        self.fully_connected_layer_2 = nn.Linear(hidden_dim, 1)

    def forward(self, node_embedding_u, node_embedding_d, time_in_day_feat, day_in_week_feat, history_data):
        """Scale ``history_data`` by a gate value in (0, 1).

        The gate is generated from the current node and time-step embeddings
        to roughly estimate the proportion of the two hidden time series.

        Args:
            node_embedding_u: rank-2 node embedding ``(num_nodes, d_u)``,
                shared across batch and time steps.
            node_embedding_d: rank-2 node embedding ``(num_nodes, d_d)``,
                shared across batch and time steps.
            time_in_day_feat: rank-4 tensor ``(batch, seq_length, num_nodes, d_t)``.
            day_in_week_feat: rank-4 tensor with the same leading three dims
                as ``time_in_day_feat``.
            history_data: tensor whose dim 1 is the number of time steps to
                gate; presumably ``(batch, T, num_nodes, C)`` -- TODO confirm.
                Only the last ``T`` steps of the time features are used.

        Returns:
            ``history_data * gate``, where ``gate`` has shape
            ``(batch, T, num_nodes, 1)`` (broadcast over the last dim).
        """
        batch_size, seq_length, _, _ = time_in_day_feat.shape

        # Only the last ``history_len`` gate steps are ever used. Slice the
        # time features *before* the MLP rather than computing and then
        # discarding the leading steps. The MLP acts independently on each
        # position, so the kept gate values are the same.
        # ``max(..., 0)`` makes an empty history give an empty gate. A ``-0:``
        # slice would keep every step and then fail to broadcast.
        history_len = history_data.shape[1]
        start = max(seq_length - history_len, 0)
        time_in_day_feat = time_in_day_feat[:, start:]
        day_in_week_feat = day_in_week_feat[:, start:]
        gate_length = seq_length - start

        def _broadcast_nodes(node_embedding):
            # (N, D) -> (batch, gate_length, N, D) as an expanded view (no copy).
            return node_embedding.unsqueeze(0).unsqueeze(0).expand(batch_size, gate_length, -1, -1)

        estimation_gate_feat = torch.cat(
            [
                time_in_day_feat,
                day_in_week_feat,
                _broadcast_nodes(node_embedding_u),
                _broadcast_nodes(node_embedding_d),
            ],
            dim=-1,
        )
        hidden = self.activation(self.fully_connected_layer_1(estimation_gate_feat))
        # Sigmoid squashes the single output channel into (0, 1).
        estimation_gate = torch.sigmoid(self.fully_connected_layer_2(hidden))
        return history_data * estimation_gate