import torch
import torch.nn as nn

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.head import TemporalProjectionHead


class InformerEncoder(nn.Module):
    """Encoder-only Informer: data embedding -> attention encoder -> temporal projection head."""

    def __init__(self, configs):
        super().__init__()

        self.seq_len = configs["seq_len"]
        self.pred_len = configs["pred_len"]

        # Select ProbSparse self-attention ("prob") or canonical full attention
        Attn = ProbAttention if configs["attn"] == "prob" else FullAttention

        # Embedding: map enc_in input channels to d_model
        self.embedding = DataEmbedding(configs["enc_in"], configs["d_model"], configs["dropout"])

        # Encoder: e_layers attention blocks (Attn), optionally interleaved with
        # distilling conv layers (Conv), followed by a final LayerNorm (Norm)
        self.encoder = Encoder(
            [
                EncoderLayer(
                    # Attn
                    AttentionLayer(Attn(False, configs["factor"], configs["dropout"], False),
                                   configs["d_model"], configs["n_heads"], False),
                    configs["d_model"], configs["d_ff"], configs["dropout"], configs["activation"],
                )
                for _ in range(configs["e_layers"])
            ],
            # Conv: distilling layers between attention blocks, used only when "distil" is set
            [ConvLayer(configs["d_model"]) for _ in range(configs["e_layers"] - 1)]
            if configs.get("distil")
            else None,
            # Norm
            norm_layer=nn.LayerNorm(configs["d_model"]),
        )

        # Forecast Head
        self.head = TemporalProjectionHead(
            d_model=configs["d_model"],
            pred_len=configs["pred_len"],
            c_out=configs["c_out"],
        )

    def forward(self, x_enc):
        # x_enc: (batch, seq_len, enc_in)
        x = self.embedding(x_enc)   # (batch, seq_len, d_model)
        x, _ = self.encoder(x)      # encoder also returns attention maps; unused here
        out = self.head(x)
        # Keep only the last pred_len steps of the head output
        return out[:, -self.pred_len :, :]
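
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It builds the
# encoder from an illustrative config and runs a dummy forward pass; every
# value below is a placeholder assumption, only the keys are the ones read
# in __init__ above.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    configs = {
        "seq_len": 96,
        "pred_len": 24,
        "enc_in": 7,
        "c_out": 7,
        "d_model": 512,
        "n_heads": 8,
        "e_layers": 2,
        "d_ff": 2048,
        "factor": 5,
        "dropout": 0.1,
        "attn": "prob",          # "prob" -> ProbAttention, anything else -> FullAttention
        "activation": "gelu",
        "distil": True,
    }

    model = InformerEncoder(configs)
    x_enc = torch.randn(4, configs["seq_len"], configs["enc_in"])  # (batch, seq_len, enc_in)
    y_hat = model(x_enc)
    print(y_hat.shape)  # expected: (4, pred_len, c_out)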