108 lines
3.6 KiB
Python
108 lines
3.6 KiB
Python
|
|
import torch.nn as nn
|
|
import torch
|
|
from torchinfo import summary
|
|
|
|
|
|
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention across the -2 dim.

    The last dim must be ``model_dim``; attention is computed across dim -2,
    so permute/transpose the tensor beforehand so that the axis to attend
    over sits at position -2.

    E.g. for an input of shape (batch_size, in_steps, num_nodes, model_dim),
    attention is performed across the nodes.

    Query and key/value may have different lengths along dim -2
    (``tgt_length`` vs ``src_length``), but key and value must share the same
    length (``src_length == K length == V length``).

    Args:
        model_dim: Size of the last (feature) dim of Q, K and V.
        num_heads: Number of attention heads; must be positive and divide
            ``model_dim`` evenly.
        mask: If True, apply a lower-triangular mask so that target position
            ``i`` only attends to source positions ``j <= i``.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``model_dim``.
    """

    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()

        # Fail fast: if model_dim is not a multiple of num_heads, the
        # torch.split in forward() either yields a different number of heads
        # than requested (silently) or chunks of unequal width (a shape error
        # in torch.cat).
        if num_heads <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask

        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value):
        """Attend ``query`` over ``key``/``value`` along dim -2.

        Args:
            query: Tensor of shape (batch_size, ..., tgt_length, model_dim).
            key: Tensor of shape (batch_size, ..., src_length, model_dim).
            value: Tensor of shape (batch_size, ..., src_length, model_dim).

        Returns:
            Tensor of shape (batch_size, ..., tgt_length, model_dim).
        """
        batch_size = query.shape[0]
        tgt_length = query.shape[-2]
        src_length = key.shape[-2]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Fold the heads into the leading dim:
        # (num_heads * batch_size, ..., length, head_dim)
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        # (num_heads * batch_size, ..., head_dim, src_length)
        key = key.transpose(-1, -2)

        # (num_heads * batch_size, ..., tgt_length, src_length)
        attn_score = (query @ key) / self.head_dim**0.5

        if self.mask:
            # Lower-triangular boolean mask: True where attention is allowed.
            mask = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=query.device
            ).tril()
            attn_score.masked_fill_(~mask, -torch.inf)  # fill in-place

        attn_score = torch.softmax(attn_score, dim=-1)
        # (num_heads * batch_size, ..., tgt_length, head_dim)
        out = attn_score @ value

        # Undo the head folding:
        # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)
        out = torch.cat(torch.split(out, batch_size, dim=0), dim=-1)

        out = self.out_proj(out)

        return out
|
|
|
|
class SelfAttentionLayer(nn.Module):
    """Post-norm block: self-attention followed by a position-wise feed-forward net.

    Each of the two sub-layers is applied as
    ``LayerNorm(inputs + Dropout(sublayer(inputs)))``.

    Args:
        model_dim: Size of the last (feature) dim of the input.
        feed_forward_dim: Hidden width of the Linear -> ReLU -> Linear net.
        num_heads: Number of heads, forwarded to ``AttentionLayer``.
        dropout: Dropout probability applied to each sub-layer's output.
        mask: Forwarded to ``AttentionLayer`` (enables its lower-triangular
            attention mask).
    """

    def __init__(
        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()

        # NOTE: submodules are created in the same order as before; the order
        # fixes parameter registration (state_dict layout) and which random
        # draws each Linear consumes during initialization.
        self.attn = AttentionLayer(model_dim, num_heads, mask)

        hidden = feed_forward_dim
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, hidden),
            nn.ReLU(inplace=True),
            nn.Linear(hidden, model_dim),
        )

        self.ln1, self.ln2 = nn.LayerNorm(model_dim), nn.LayerNorm(model_dim)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        """Run attention + feed-forward with axis ``dim`` as the attended axis.

        Args:
            x: Input whose last dim is ``model_dim``.
            dim: Axis to attend over; it is swapped with -2 for the
                computation and swapped back afterwards.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Bring the attended axis to -2: (batch_size, ..., length, model_dim)
        x = x.transpose(dim, -2)

        # Sub-layer 1: self-attention, dropout, residual add, LayerNorm.
        hidden = self.ln1(x + self.dropout1(self.attn(x, x, x)))

        # Sub-layer 2: feed-forward, dropout, residual add, LayerNorm.
        hidden = self.ln2(hidden + self.dropout2(self.feed_forward(hidden)))

        # Put the attended axis back where the caller had it.
        return hidden.transpose(dim, -2)