# TrafficWheel/model/Informer/model.py
import torch
import torch.nn as nn
from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.head import TemporalProjectionHead
class InformerEncoder(nn.Module):
    """Encoder-only Informer for multi-channel sequence forecasting.

    - Consumes only the encoder input ``x_enc`` (no decoder tokens).
    - Applies no input normalization.
    - Multi-channel friendly: works with any number of input channels.

    Expected ``configs`` keys: ``seq_len``, ``pred_len``, ``attn``,
    ``enc_in``, ``d_model``, ``dropout``, ``factor``, ``n_heads``,
    ``d_ff``, ``activation``, ``e_layers``, ``c_out``; optional
    ``distil`` (defaults to False).
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs["seq_len"]
        self.pred_len = configs["pred_len"]

        # Select ProbSparse attention ("prob") or standard full attention.
        Attn = ProbAttention if configs["attn"] == "prob" else FullAttention

        # Embedding: projects raw input channels into the model dimension.
        self.embedding = DataEmbedding(
            configs["enc_in"],
            configs["d_model"],
            configs["dropout"],
        )

        # Informer encoder stack. Distilling conv layers (one fewer than
        # e_layers) are inserted only when configs["distil"] is truthy;
        # otherwise the encoder runs without down-sampling.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(
                            False,  # presumably mask_flag: unmasked self-attention — TODO confirm against attn.py
                            configs["factor"],
                            attention_dropout=configs["dropout"],
                            output_attention=False,
                        ),
                        configs["d_model"],
                        configs["n_heads"],
                        mix=False,
                    ),
                    configs["d_model"],
                    configs["d_ff"],
                    dropout=configs["dropout"],
                    activation=configs["activation"],
                )
                for _ in range(configs["e_layers"])
            ],
            [
                ConvLayer(configs["d_model"])
                for _ in range(configs["e_layers"] - 1)
            ]
            if configs.get("distil", False)
            else None,
            norm_layer=nn.LayerNorm(configs["d_model"]),
        )

        # Forecast head: maps encoder features to pred_len output steps.
        self.head = TemporalProjectionHead(
            d_model=configs["d_model"],
            pred_len=configs["pred_len"],
            c_out=configs["c_out"],
        )

    def forward(self, x_enc):
        """Forecast from the encoder input.

        Args:
            x_enc: Tensor of shape [B, L, C].

        Returns:
            The last ``pred_len`` time steps of the head's output,
            shape [B, pred_len, c_out].
        """
        x = self.embedding(x_enc)
        # Encoder returns (features, attentions); attention maps are
        # discarded since output_attention=False above.
        x, _ = self.encoder(x)
        out = self.head(x)
        # Keep only the final pred_len steps in case the head emits more.
        return out[:, -self.pred_len :, :]