import torch
import torch.nn as nn

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.head import TemporalProjectionHead


class InformerEncoder(nn.Module):
    """Encoder-only Informer forecaster.

    The model consumes only the value tensor ``x_enc``; no time-feature
    marks are passed to the embedding. It embeds ``x_enc``, runs it through
    a stack of Informer encoder layers, and hands the encoded sequence to a
    ``TemporalProjectionHead``. The last ``pred_len`` positions (along
    dim 1) of the head's output are returned.

    ``x_enc`` is not normalized by this module before embedding. The only
    normalization constructed here is the ``nn.LayerNorm`` passed to the
    encoder as its ``norm_layer``. Input and output channel counts are set
    independently through ``enc_in`` and ``c_out``.

    Args:
        configs: Mapping with keys ``seq_len``, ``pred_len``, ``attn``,
            ``enc_in``, ``d_model``, ``dropout``, ``e_layers``, ``factor``,
            ``n_heads``, ``d_ff``, ``activation`` and ``c_out``. ``distil``
            is optional (read with ``.get``, default ``False``).
    """

    def __init__(self, configs):
        super().__init__()

        # seq_len is kept as an attribute only; nothing in this class reads it.
        self.seq_len = configs["seq_len"]
        self.pred_len = configs["pred_len"]

        # "prob" selects ProbAttention; every other value selects FullAttention.
        if configs["attn"] == "prob":
            attention_cls = ProbAttention
        else:
            attention_cls = FullAttention

        # Value embedding (called without time marks in forward).
        in_channels = configs["enc_in"]
        d_model = configs["d_model"]
        dropout = configs["dropout"]
        # NOTE(review): dropout is passed positionally as the 3rd argument;
        # confirm the project's DataEmbedding signature expects it there.
        self.embedding = DataEmbedding(in_channels, d_model, dropout)

        # Stack of encoder layers. The attention kernel is built first, then
        # wrapped in AttentionLayer, then in EncoderLayer.
        n_layers = configs["e_layers"]
        encoder_layers = []
        for _ in range(n_layers):
            attention = AttentionLayer(
                attention_cls(
                    False,  # first positional flag, always False here
                    configs["factor"],
                    attention_dropout=dropout,
                    output_attention=False,
                ),
                d_model,
                configs["n_heads"],
                mix=False,
            )
            encoder_layers.append(
                EncoderLayer(
                    attention,
                    d_model,
                    configs["d_ff"],
                    dropout=dropout,
                    activation=configs["activation"],
                )
            )

        # Optional distilling: one ConvLayer between consecutive encoder layers.
        distil_layers = None
        if configs.get("distil", False):
            distil_layers = [ConvLayer(d_model) for _ in range(n_layers - 1)]

        final_norm = nn.LayerNorm(d_model)
        self.encoder = Encoder(encoder_layers, distil_layers, norm_layer=final_norm)

        # Forecast head mapping encoder features to c_out output channels.
        self.head = TemporalProjectionHead(
            d_model=d_model,
            pred_len=self.pred_len,
            c_out=configs["c_out"],
        )

    def forward(self, x_enc):
        """Produce a forecast from the input window.

        Args:
            x_enc: Input values, documented by the original author as
                ``[B, L, C]`` (batch, length, channels).

        Returns:
            The head output restricted to its last ``pred_len`` positions
            along dim 1. Presumably ``[B, pred_len, c_out]``; confirm
            against ``TemporalProjectionHead``.
        """
        embedded = self.embedding(x_enc)
        # The encoder returns a pair; the second element is discarded.
        encoded, _ = self.encoder(embedded)
        # NOTE(review): if distil shortens the sequence and the head keeps the
        # length, fewer than pred_len steps could remain here. Verify.
        forecast = self.head(encoded)
        return forecast[:, -self.pred_len:, :]