import torch as th
import torch.nn as nn
from torch.nn import MultiheadAttention


class RNNLayer(nn.Module):
    """Recurrent layer that runs a shared GRU cell over the time dimension of every node."""

    def __init__(self, hidden_dim, dropout=None):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.gru_cell = nn.GRUCell(hidden_dim, hidden_dim)
        # nn.Dropout(None) raises, so fall back to a no-op when no dropout rate is given.
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()

    def forward(self, X):
        # X: [batch_size, seq_len, num_nodes, hidden_dim]
        batch_size, seq_len, num_nodes, hidden_dim = X.shape
        # Fold the node dimension into the batch so one GRU cell processes all nodes.
        X = X.transpose(1, 2).reshape(batch_size * num_nodes, seq_len, hidden_dim)
        hx = th.zeros_like(X[:, 0, :])          # initial hidden state: [batch_size * num_nodes, hidden_dim]
        output = []
        for t in range(seq_len):
            hx = self.gru_cell(X[:, t, :], hx)  # one GRU step per time index
            output.append(hx)
        # Stack along time first: [seq_len, batch_size * num_nodes, hidden_dim],
        # i.e. the time-major layout expected by TransformerLayer below.
        output = th.stack(output, dim=0)
        output = self.dropout(output)
        return output


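# Hedged usage sketch (not part of the original module): documents the shape contract of
# RNNLayer under assumed sizes B=2, L=12, N=3, D=64. Left uncalled so importing the
# module stays side-effect free.
def _rnn_layer_shape_example():
    layer = RNNLayer(hidden_dim=64, dropout=0.1)
    x = th.rand(2, 12, 3, 64)                  # [B, L, N, D]
    out = layer(x)
    assert out.shape == (12, 2 * 3, 64)        # time-major: [L, B * N, D]
    return out

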
class TransformerLayer(nn.Module):
    """Single multi-head attention block with dropout applied to the attention output."""

    def __init__(self, hidden_dim, num_heads=4, dropout=None, bias=True):
        super().__init__()
        # MultiheadAttention expects a float dropout rate; treat None as "no dropout".
        attn_dropout = dropout if dropout is not None else 0.0
        self.multi_head_self_attention = MultiheadAttention(hidden_dim, num_heads, dropout=attn_dropout, bias=bias)
        self.dropout = nn.Dropout(dropout) if dropout is not None else nn.Identity()

    def forward(self, X, K, V):
        # X, K, V: [seq_len, batch_size, hidden_dim] (MultiheadAttention's default
        # batch_first=False layout). Only the attention output is kept; the attention
        # weights returned alongside it are discarded.
        hidden_states_MSA = self.multi_head_self_attention(X, K, V)[0]
        hidden_states_MSA = self.dropout(hidden_states_MSA)
        return hidden_states_MSA
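
# Hedged smoke test (not part of the original module): chains the two layers under
# assumed sizes B=2, L=12, N=3, D=64, using self-attention (query = key = value).
if __name__ == "__main__":
    rnn = RNNLayer(hidden_dim=64, dropout=0.1)
    attn = TransformerLayer(hidden_dim=64, num_heads=4, dropout=0.1)
    x = th.rand(2, 12, 3, 64)                  # [B, L, N, D]
    h = rnn(x)                                 # [L, B * N, D] = [12, 6, 64]
    y = attn(h, h, h)                          # same shape after self-attention
    print(h.shape, y.shape)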