# TrafficWheel/model/TEDDCF/ISTF.py

import torch.nn as nn
import torch
from torchinfo import summary
class AttentionLayer(nn.Module):
"""Perform attention across the -2 dim (the -1 dim is `model_dim`).
Make sure the tensor is permuted to correct shape before attention.
E.g.
- Input shape (batch_size, in_steps, num_nodes, model_dim).
- Then the attention will be performed across the nodes.
Also, it supports different src and tgt length.
But must `src length == K length == V length`.
"""
    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()

        self.model_dim = model_dim  # e.g. 152
        self.num_heads = num_heads
        self.mask = mask

        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)  # [152, 152]
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)
    def forward(self, query, key, value):
        # Q    (batch_size, ..., tgt_length, model_dim)
        # K, V (batch_size, ..., src_length, model_dim)
        batch_size = query.shape[0]
        tgt_length = query.shape[-2]
        src_length = key.shape[-2]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim)
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        key = key.transpose(-1, -2)  # (num_heads * batch_size, ..., head_dim, src_length)

        # (num_heads * batch_size, ..., tgt_length, src_length)
        attn_score = (query @ key) / self.head_dim**0.5

        if self.mask:
            mask = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=query.device
            ).tril()  # lower triangular part of the matrix
            attn_score.masked_fill_(~mask, -torch.inf)  # fill in-place

        attn_score = torch.softmax(attn_score, dim=-1)
        out = attn_score @ value  # (num_heads * batch_size, ..., tgt_length, head_dim)

        # Merge heads back: (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)
        out = torch.cat(torch.split(out, batch_size, dim=0), dim=-1)
        out = self.out_proj(out)

        return out
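

# Shape walk-through (an illustrative note, using the model_dim=152 mentioned in the
# comments above; the actual configuration comes from the rest of TEDDCF):
#   with num_heads=8, head_dim = 152 // 8 = 19;
#   splitting Q/K/V into 8 chunks of 19 along dim -1 and concatenating along dim 0
#   turns (B, ..., L, 152) into (8*B, ..., L, 19), so every head is computed in a
#   single batched matmul; splitting along dim 0 and re-concatenating along dim -1
#   at the end restores (B, ..., tgt_length, 152) before `out_proj`.

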
class SelfAttentionLayer(nn.Module):
    def __init__(
        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()

        self.attn = AttentionLayer(model_dim, num_heads, mask)
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feed_forward_dim),  # [152, 256]
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),  # [256, 152]
        )
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        # Move the attended dim next to the feature dim, run attention, then move it back.
        x = x.transpose(dim, -2)
        # x: (batch_size, ..., length, model_dim)
        residual = x
        out = self.attn(x, x, x)  # (batch_size, ..., length, model_dim)
        out = self.dropout1(out)
        out = self.ln1(residual + out)

        residual = out
        out = self.feed_forward(out)  # (batch_size, ..., length, model_dim)
        out = self.dropout2(out)
        out = self.ln2(residual + out)

        out = out.transpose(dim, -2)
        return out
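

def _example_usage():
    """Illustrative sketch, not part of the original model: check that
    SelfAttentionLayer preserves the input shape whether attention runs across
    the time dim (dim=1) or the node dim (dim=2). The sizes below (batch 2,
    12 steps, 170 nodes, model_dim 152, feed_forward_dim 256) are assumptions
    chosen to match the shape comments above; the function and variable names
    here are hypothetical."""
    x = torch.randn(2, 12, 170, 152)  # (batch_size, in_steps, num_nodes, model_dim)

    temporal_attn = SelfAttentionLayer(model_dim=152, feed_forward_dim=256, num_heads=8)
    spatial_attn = SelfAttentionLayer(model_dim=152, feed_forward_dim=256, num_heads=8)

    out_t = temporal_attn(x, dim=1)  # attention across the 12 time steps
    out_s = spatial_attn(x, dim=2)   # attention across the 170 nodes

    assert out_t.shape == x.shape and out_s.shape == x.shape
    return out_t, out_s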