TrafficWheel/model/STAEFormer/STAEFormer.py

248 lines
8.8 KiB
Python
Executable File

import torch.nn as nn
import torch
class AttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention across the -2 dim.

    The -1 dim is `model_dim`; permute the tensor so the axis to attend over
    sits at -2 before calling. E.g. for an input of shape
    (batch_size, in_steps, num_nodes, model_dim) the attention is performed
    across the nodes.

    Query and key/value may have different lengths, but
    `src length == K length == V length` must hold.

    Args:
        model_dim: Size of the last dim of query/key/value and of the output.
            Must be a positive multiple of `num_heads`.
        num_heads: Number of attention heads; each head gets
            `model_dim // num_heads` features.
        mask: If True, apply a lower-triangular mask so target position i
            only attends to source positions 0..i.

    Raises:
        ValueError: If `model_dim` is not a positive multiple of `num_heads`.
    """

    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()
        # The head split in forward() cuts the feature dim into chunks of
        # head_dim; without exact divisibility it either fails with an opaque
        # shape error or silently runs with a different number of heads.
        if num_heads <= 0 or model_dim <= 0 or model_dim % num_heads != 0:
            raise ValueError(
                f"model_dim ({model_dim}) must be a positive multiple of "
                f"num_heads ({num_heads})"
            )
        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask
        self.head_dim = model_dim // num_heads
        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)
        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value):
        """Attend from `query` positions to `key`/`value` positions.

        Args:
            query: Tensor of shape (batch_size, ..., tgt_length, model_dim).
            key: Tensor of shape (batch_size, ..., src_length, model_dim).
            value: Tensor of shape (batch_size, ..., src_length, model_dim).

        Returns:
            Tensor of shape (batch_size, ..., tgt_length, model_dim).
        """
        batch_size = query.shape[0]
        tgt_length = query.shape[-2]
        src_length = key.shape[-2]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Fold the heads into the leading dim:
        # (num_heads * batch_size, ..., length, head_dim)
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        # (num_heads * batch_size, ..., head_dim, src_length)
        key = key.transpose(-1, -2)

        # (num_heads * batch_size, ..., tgt_length, src_length)
        attn_score = (query @ key) / self.head_dim**0.5

        if self.mask:
            # Lower-triangular boolean mask: keep j <= i, block the rest.
            mask = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=query.device
            ).tril()
            attn_score.masked_fill_(~mask, -torch.inf)  # fill in-place

        attn_score = torch.softmax(attn_score, dim=-1)
        # (num_heads * batch_size, ..., tgt_length, head_dim)
        out = attn_score @ value

        # Unfold the heads back into the feature dim:
        # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)
        out = torch.cat(torch.split(out, batch_size, dim=0), dim=-1)
        return self.out_proj(out)
class SelfAttentionLayer(nn.Module):
    """Post-norm Transformer encoder layer that attends along one chosen axis.

    Two sub-layers — self-attention and a two-layer ReLU feed-forward
    network — each followed by dropout, a residual add and LayerNorm.

    Args:
        model_dim: Size of the last (feature) dimension.
        feed_forward_dim: Hidden size of the feed-forward network.
        num_heads: Number of attention heads passed to AttentionLayer.
        dropout: Dropout probability applied to both sub-layer outputs.
        mask: Forwarded to AttentionLayer to enable its lower-triangular mask.
    """

    def __init__(
        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()
        # Creation order is kept fixed so parameter init / state_dict keys match.
        self.attn = AttentionLayer(model_dim, num_heads, mask)
        hidden = [
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),
        ]
        self.feed_forward = nn.Sequential(*hidden)
        self.ln1, self.ln2 = nn.LayerNorm(model_dim), nn.LayerNorm(model_dim)
        self.dropout1, self.dropout2 = nn.Dropout(dropout), nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        """Apply the layer along axis `dim` of `x`.

        Args:
            x: Tensor whose last dim is `model_dim`.
            dim: Axis to attend over. It is swapped into position -2 for the
                computation and swapped back before returning.

        Returns:
            Tensor with the same shape as `x`.
        """
        # Move the attended axis next to the feature axis:
        # (batch_size, ..., length, model_dim)
        h = x.transpose(dim, -2)
        h = self.ln1(h + self.dropout1(self.attn(h, h, h)))
        h = self.ln2(h + self.dropout2(self.feed_forward(h)))
        return h.transpose(dim, -2)
class STAEformer(nn.Module):
    """Transformer forecaster over a (time step x node) grid.

    Each (step, node) cell gets a feature vector formed by concatenating a
    linear projection of the raw input channels with the enabled embeddings
    (time-of-day, day-of-week, per-node, and a learned per-(step, node)
    "adaptive" embedding). `num_layers` self-attention layers are applied
    across the time axis, then `num_layers` across the node axis, and the
    result is projected to `out_steps` future steps.

    Args:
        args: Mapping of hyper-parameters. ``"num_nodes"`` is required; the
            other keys are optional and fall back to the defaults read in
            ``__init__``: in_steps, out_steps, steps_per_day, input_dim,
            output_dim, input_embedding_dim, tod_embedding_dim,
            dow_embedding_dim, spatial_embedding_dim, adaptive_embedding_dim,
            feed_forward_dim, num_heads, num_layers, dropout, use_mixed_proj.

    Note:
        ``model_dim`` (the sum of all embedding sizes; 152 with the defaults)
        is the attention width and should be divisible by ``num_heads``.
    """

    def __init__(self, args):
        super().__init__()
        self.num_nodes = args["num_nodes"]
        self.in_steps = args.get("in_steps", 12)
        self.out_steps = args.get("out_steps", 12)
        self.steps_per_day = args.get("steps_per_day", 288)
        self.input_dim = args.get("input_dim", 3)
        self.output_dim = args.get("output_dim", 1)
        self.input_embedding_dim = args.get("input_embedding_dim", 24)
        self.tod_embedding_dim = args.get("tod_embedding_dim", 24)
        self.dow_embedding_dim = args.get("dow_embedding_dim", 24)
        self.spatial_embedding_dim = args.get("spatial_embedding_dim", 0)
        self.adaptive_embedding_dim = args.get("adaptive_embedding_dim", 80)
        self.feed_forward_dim = args.get("feed_forward_dim", 256)
        self.num_heads = args.get("num_heads", 4)
        self.num_layers = args.get("num_layers", 3)
        self.dropout = args.get("dropout", 0.1)
        self.use_mixed_proj = args.get("use_mixed_proj", True)

        # Width of the concatenated per-(step, node) feature vector.
        self.model_dim = (
            self.input_embedding_dim
            + self.tod_embedding_dim
            + self.dow_embedding_dim
            + self.spatial_embedding_dim
            + self.adaptive_embedding_dim
        )

        # Submodules are created in a fixed order so parameter init and
        # state_dict keys are stable.
        self.input_proj = nn.Linear(self.input_dim, self.input_embedding_dim)
        if self.tod_embedding_dim > 0:
            self.tod_embedding = nn.Embedding(
                self.steps_per_day, self.tod_embedding_dim
            )
        if self.dow_embedding_dim > 0:
            self.dow_embedding = nn.Embedding(7, self.dow_embedding_dim)
        if self.spatial_embedding_dim > 0:
            self.node_emb = nn.Parameter(
                torch.empty(self.num_nodes, self.spatial_embedding_dim)
            )
            nn.init.xavier_uniform_(self.node_emb)
        if self.adaptive_embedding_dim > 0:
            # One learned vector per (input step, node), shared across the batch.
            self.adaptive_embedding = nn.Parameter(
                torch.empty(
                    self.in_steps, self.num_nodes, self.adaptive_embedding_dim
                )
            )
            nn.init.xavier_uniform_(self.adaptive_embedding)

        if self.use_mixed_proj:
            # Maps all input steps of a node jointly to all output steps.
            self.output_proj = nn.Linear(
                self.in_steps * self.model_dim, self.out_steps * self.output_dim
            )
        else:
            # Maps along time first, then along features.
            self.temporal_proj = nn.Linear(self.in_steps, self.out_steps)
            self.output_proj = nn.Linear(self.model_dim, self.output_dim)

        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.model_dim, self.feed_forward_dim, self.num_heads, self.dropout
                )
                for _ in range(self.num_layers)
            ]
        )
        self.attn_layers_s = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.model_dim, self.feed_forward_dim, self.num_heads, self.dropout
                )
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, x):
        """Forecast `out_steps` future steps.

        Args:
            x: Tensor of shape (batch_size, in_steps, num_nodes, C). The first
                `input_dim` channels go through `input_proj`. When enabled,
                channel 1 is read as time-of-day (multiplied by
                `steps_per_day` and truncated to an index — presumably a
                fraction of a day in [0, 1)) and channel 2 as day-of-week
                (truncated to an index in 0..6).

        Returns:
            Tensor of shape (batch_size, out_steps, num_nodes, output_dim).
        """
        batch_size = x.shape[0]
        # (batch_size, in_steps, num_nodes, model_dim)
        x = self._embed(x, batch_size)
        for attn in self.attn_layers_t:
            x = attn(x, dim=1)  # attend across in_steps
        for attn in self.attn_layers_s:
            x = attn(x, dim=2)  # attend across num_nodes
        return self._project(x, batch_size)

    def _embed(self, x, batch_size):
        """Concatenate the input projection with every enabled embedding."""
        # Read the calendar channels before narrowing x to the raw features.
        if self.tod_embedding_dim > 0:
            tod = x[..., 1]
        if self.dow_embedding_dim > 0:
            dow = x[..., 2]

        # Slice exactly `input_dim` channels so the slice always matches
        # input_proj's in_features.
        features = [self.input_proj(x[..., : self.input_dim])]
        if self.tod_embedding_dim > 0:
            features.append(self.tod_embedding((tod * self.steps_per_day).long()))
        if self.dow_embedding_dim > 0:
            features.append(self.dow_embedding(dow.long()))
        if self.spatial_embedding_dim > 0:
            features.append(
                self.node_emb.expand(batch_size, self.in_steps, *self.node_emb.shape)
            )
        if self.adaptive_embedding_dim > 0:
            features.append(
                self.adaptive_embedding.expand(
                    size=(batch_size, *self.adaptive_embedding.shape)
                )
            )
        return torch.cat(features, dim=-1)

    def _project(self, x, batch_size):
        """Map (B, in_steps, N, model_dim) to (B, out_steps, N, output_dim)."""
        if self.use_mixed_proj:
            # (batch_size, num_nodes, in_steps * model_dim)
            out = x.transpose(1, 2).reshape(
                batch_size, self.num_nodes, self.in_steps * self.model_dim
            )
            out = self.output_proj(out).view(
                batch_size, self.num_nodes, self.out_steps, self.output_dim
            )
            return out.transpose(1, 2)
        # (batch_size, model_dim, num_nodes, in_steps -> out_steps)
        out = self.temporal_proj(x.transpose(1, 3))
        return self.output_proj(out.transpose(1, 3))
if __name__ == "__main__":
    # Smoke test. STAEformer takes a single mapping of hyper-parameters, so
    # the values must be passed by key rather than positionally.
    num_nodes, in_steps, out_steps = 207, 12, 12
    model = STAEformer(
        {
            "num_nodes": num_nodes,
            "in_steps": in_steps,
            "out_steps": out_steps,
            "input_dim": 1,
        }
    )
    batch_size = 2
    # Channels: raw value, time-of-day as a fraction of a day, day-of-week index.
    value = torch.randn(batch_size, in_steps, num_nodes, 1)
    tod = torch.rand(batch_size, in_steps, num_nodes, 1)
    dow = torch.randint(0, 7, (batch_size, in_steps, num_nodes, 1)).float()
    out = model(torch.cat([value, tod, dow], dim=-1))
    print(out.shape)  # expected: (batch_size, out_steps, num_nodes, 1)