import torch
import torch.nn as nn


class AttentionLayer(nn.Module):
    """Perform attention across the -2 dim (the -1 dim is `model_dim`).

    Make sure the tensor is permuted to the correct shape before attention.

    E.g.
    - Input shape (batch_size, in_steps, num_nodes, model_dim).
    - Then attention is performed across the nodes.

    The query (tgt) length may differ from the key/value (src) length,
    but key and value must have the same length.

    A shape-check sketch is given in `_attention_layer_demo` below.
    """

    def __init__(self, model_dim, num_heads=8, mask=False):
        super().__init__()

        self.model_dim = model_dim
        self.num_heads = num_heads
        self.mask = mask

        self.head_dim = model_dim // num_heads

        self.FC_Q = nn.Linear(model_dim, model_dim)
        self.FC_K = nn.Linear(model_dim, model_dim)
        self.FC_V = nn.Linear(model_dim, model_dim)

        self.out_proj = nn.Linear(model_dim, model_dim)

    def forward(self, query, key, value):
        # Q    (batch_size, ..., tgt_length, model_dim)
        # K, V (batch_size, ..., src_length, model_dim)
        batch_size = query.shape[0]
        tgt_length = query.shape[-2]
        src_length = key.shape[-2]

        query = self.FC_Q(query)
        key = self.FC_K(key)
        value = self.FC_V(value)

        # Qhead, Khead, Vhead (num_heads * batch_size, ..., length, head_dim)
        query = torch.cat(torch.split(query, self.head_dim, dim=-1), dim=0)
        key = torch.cat(torch.split(key, self.head_dim, dim=-1), dim=0)
        value = torch.cat(torch.split(value, self.head_dim, dim=-1), dim=0)

        key = key.transpose(-1, -2)  # (num_heads * batch_size, ..., head_dim, src_length)

        attn_score = (
            query @ key
        ) / self.head_dim**0.5  # (num_heads * batch_size, ..., tgt_length, src_length)

        if self.mask:
            mask = torch.ones(
                tgt_length, src_length, dtype=torch.bool, device=query.device
            ).tril()  # lower triangular part of the matrix
            attn_score.masked_fill_(~mask, -torch.inf)  # fill in-place

        attn_score = torch.softmax(attn_score, dim=-1)
        out = attn_score @ value  # (num_heads * batch_size, ..., tgt_length, head_dim)
        out = torch.cat(
            torch.split(out, batch_size, dim=0), dim=-1
        )  # (batch_size, ..., tgt_length, head_dim * num_heads = model_dim)

        out = self.out_proj(out)

        return out


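# Shape-check sketch for `AttentionLayer` (illustrative only; the helper name
# `_attention_layer_demo` is not part of the original module). Attention always
# runs over the -2 axis, so the layout of the input decides what attends to
# what, and the query length may differ from the key/value length.
def _attention_layer_demo():
    # (batch_size, in_steps, num_nodes, model_dim): the 207 nodes attend to
    # each other independently at every time step.
    q = torch.rand(8, 12, 207, 64)
    attn = AttentionLayer(model_dim=64, num_heads=8)
    out = attn(q, q, q)
    assert out.shape == (8, 12, 207, 64)

    # The query may have a different -2 length than key/value, but key and
    # value must share the same length.
    q_short = torch.rand(8, 12, 5, 64)
    kv = torch.rand(8, 12, 207, 64)
    out = attn(q_short, kv, kv)
    assert out.shape == (8, 12, 5, 64)

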
class SelfAttentionLayer(nn.Module):
    def __init__(
        self, model_dim, feed_forward_dim=2048, num_heads=8, dropout=0, mask=False
    ):
        super().__init__()

        self.attn = AttentionLayer(model_dim, num_heads, mask)
        self.feed_forward = nn.Sequential(
            nn.Linear(model_dim, feed_forward_dim),
            nn.ReLU(inplace=True),
            nn.Linear(feed_forward_dim, model_dim),
        )
        self.ln1 = nn.LayerNorm(model_dim)
        self.ln2 = nn.LayerNorm(model_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, dim=-2):
        x = x.transpose(dim, -2)
        # x: (batch_size, ..., length, model_dim)
        residual = x
        out = self.attn(x, x, x)  # (batch_size, ..., length, model_dim)
        out = self.dropout1(out)
        out = self.ln1(residual + out)

        residual = out
        out = self.feed_forward(out)  # (batch_size, ..., length, model_dim)
        out = self.dropout2(out)
        out = self.ln2(residual + out)

        out = out.transpose(dim, -2)
        return out


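# Usage sketch for the `dim` argument of `SelfAttentionLayer` (illustrative
# only; `_self_attention_dim_demo` is not part of the original module). The
# layer transposes `dim` to -2, applies self-attention there, and transposes
# back, so one block serves as either a temporal or a spatial layer.
def _self_attention_dim_demo():
    x = torch.rand(8, 12, 207, 64)  # (batch_size, in_steps, num_nodes, model_dim)
    layer = SelfAttentionLayer(model_dim=64, feed_forward_dim=128, num_heads=8)
    t_out = layer(x, dim=1)  # attention across the 12 time steps
    s_out = layer(x, dim=2)  # attention across the 207 nodes
    assert t_out.shape == x.shape and s_out.shape == x.shape

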
class STAEformer(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.num_nodes = args["num_nodes"]
        self.in_steps = args.get("in_steps", 12)
        self.out_steps = args.get("out_steps", 12)
        self.steps_per_day = args.get("steps_per_day", 288)
        self.input_dim = args.get("input_dim", 3)
        self.output_dim = args.get("output_dim", 1)
        self.input_embedding_dim = args.get("input_embedding_dim", 24)
        self.tod_embedding_dim = args.get("tod_embedding_dim", 24)
        self.dow_embedding_dim = args.get("dow_embedding_dim", 24)
        self.spatial_embedding_dim = args.get("spatial_embedding_dim", 0)
        self.adaptive_embedding_dim = args.get("adaptive_embedding_dim", 80)
        self.feed_forward_dim = args.get("feed_forward_dim", 256)
        self.num_heads = args.get("num_heads", 4)
        self.num_layers = args.get("num_layers", 3)
        self.dropout = args.get("dropout", 0.1)
        self.use_mixed_proj = args.get("use_mixed_proj", True)

        self.model_dim = (
            self.input_embedding_dim
            + self.tod_embedding_dim
            + self.dow_embedding_dim
            + self.spatial_embedding_dim
            + self.adaptive_embedding_dim
        )

        self.input_proj = nn.Linear(self.input_dim, self.input_embedding_dim)
        if self.tod_embedding_dim > 0:
            self.tod_embedding = nn.Embedding(
                self.steps_per_day, self.tod_embedding_dim
            )
        if self.dow_embedding_dim > 0:
            self.dow_embedding = nn.Embedding(7, self.dow_embedding_dim)
        if self.spatial_embedding_dim > 0:
            self.node_emb = nn.Parameter(
                torch.empty(self.num_nodes, self.spatial_embedding_dim)
            )
            nn.init.xavier_uniform_(self.node_emb)
        if self.adaptive_embedding_dim > 0:
            self.adaptive_embedding = nn.init.xavier_uniform_(
                nn.Parameter(
                    torch.empty(
                        self.in_steps, self.num_nodes, self.adaptive_embedding_dim
                    )
                )
            )

        if self.use_mixed_proj:
            self.output_proj = nn.Linear(
                self.in_steps * self.model_dim, self.out_steps * self.output_dim
            )
        else:
            self.temporal_proj = nn.Linear(self.in_steps, self.out_steps)
            self.output_proj = nn.Linear(self.model_dim, self.output_dim)

        self.attn_layers_t = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.model_dim, self.feed_forward_dim, self.num_heads, self.dropout
                )
                for _ in range(self.num_layers)
            ]
        )

        self.attn_layers_s = nn.ModuleList(
            [
                SelfAttentionLayer(
                    self.model_dim, self.feed_forward_dim, self.num_heads, self.dropout
                )
                for _ in range(self.num_layers)
            ]
        )

    def forward(self, x):
        # x: (batch_size, in_steps, num_nodes, 3) with channels
        # (signal, time-of-day in [0, 1), day-of-week in {0, ..., 6})
        batch_size = x.shape[0]

        if self.tod_embedding_dim > 0:
            tod = x[..., 1]
        if self.dow_embedding_dim > 0:
            dow = x[..., 2]
        # Keep the first `input_dim` channels so the slice matches the
        # in_features of `input_proj` (3 by default).
        x = x[..., : self.input_dim]

        x = self.input_proj(x)  # (batch_size, in_steps, num_nodes, input_embedding_dim)
        features = [x]
        if self.tod_embedding_dim > 0:
            tod_emb = self.tod_embedding(
                (tod * self.steps_per_day).long()
            )  # (batch_size, in_steps, num_nodes, tod_embedding_dim)
            features.append(tod_emb)
        if self.dow_embedding_dim > 0:
            dow_emb = self.dow_embedding(
                dow.long()
            )  # (batch_size, in_steps, num_nodes, dow_embedding_dim)
            features.append(dow_emb)
        if self.spatial_embedding_dim > 0:
            spatial_emb = self.node_emb.expand(
                batch_size, self.in_steps, *self.node_emb.shape
            )
            features.append(spatial_emb)
        if self.adaptive_embedding_dim > 0:
            adp_emb = self.adaptive_embedding.expand(
                size=(batch_size, *self.adaptive_embedding.shape)
            )
            features.append(adp_emb)
        x = torch.cat(features, dim=-1)  # (batch_size, in_steps, num_nodes, model_dim)

        for attn in self.attn_layers_t:
            x = attn(x, dim=1)
        for attn in self.attn_layers_s:
            x = attn(x, dim=2)
        # (batch_size, in_steps, num_nodes, model_dim)

        if self.use_mixed_proj:
            out = x.transpose(1, 2)  # (batch_size, num_nodes, in_steps, model_dim)
            out = out.reshape(
                batch_size, self.num_nodes, self.in_steps * self.model_dim
            )
            out = self.output_proj(out).view(
                batch_size, self.num_nodes, self.out_steps, self.output_dim
            )
            out = out.transpose(1, 2)  # (batch_size, out_steps, num_nodes, output_dim)
        else:
            out = x.transpose(1, 3)  # (batch_size, model_dim, num_nodes, in_steps)
            out = self.temporal_proj(
                out
            )  # (batch_size, model_dim, num_nodes, out_steps)
            out = self.output_proj(
                out.transpose(1, 3)
            )  # (batch_size, out_steps, num_nodes, output_dim)

        return out


if __name__ == "__main__":
    # The constructor takes a single config dict; only "num_nodes" is required,
    # the remaining hyperparameters fall back to the defaults above.
    model = STAEformer({"num_nodes": 207, "in_steps": 12, "out_steps": 12})
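    # A minimal forward-pass smoke test (a sketch, not from the original file):
    # the input layout is inferred from `forward` above, where channel 0 is the
    # signal, channel 1 is time-of-day scaled to [0, 1), and channel 2 is
    # day-of-week as an integer in {0, ..., 6}.
    x = torch.rand(2, 12, 207, 3)  # (batch_size, in_steps, num_nodes, 3)
    x[..., 2] = torch.randint(0, 7, (2, 12, 207)).float()
    y = model(x)
    print(y.shape)  # expected: torch.Size([2, 12, 207, 1])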