TrafficWheel/model/Informer/model.py

49 lines
1.6 KiB
Python

import torch
import torch.nn as nn
from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.head import TemporalProjectionHead
class InformerEncoder(nn.Module):
    """Encoder-only Informer-style model: embedding -> encoder -> projection head.

    The constructor reads the following keys from ``configs`` (a mapping that
    supports ``[]`` and ``.get``):

    * ``seq_len``, ``pred_len`` -- stored on the instance; ``pred_len`` is also
      forwarded to the head and used to slice the output in ``forward``.
    * ``attn`` -- ``"prob"`` selects ``ProbAttention``; any other value selects
      ``FullAttention``.
    * ``enc_in``, ``d_model``, ``dropout`` -- passed to ``DataEmbedding``.
    * ``factor``, ``n_heads``, ``d_ff``, ``activation``, ``e_layers`` -- used
      to build ``e_layers`` encoder layers.
    * ``distil`` (optional) -- when truthy, ``e_layers - 1`` ``ConvLayer``
      instances are handed to the encoder; otherwise ``None`` is passed.
    * ``c_out`` -- passed to ``TemporalProjectionHead``.
    """

    def __init__(self, configs):
        super().__init__()
        self.seq_len = configs["seq_len"]
        self.pred_len = configs["pred_len"]

        # Anything other than "prob" falls back to full attention.
        if configs["attn"] == "prob":
            attention_cls = ProbAttention
        else:
            attention_cls = FullAttention

        enc_in = configs["enc_in"]
        d_model = configs["d_model"]
        dropout = configs["dropout"]

        # Embedding
        self.embedding = DataEmbedding(enc_in, d_model, dropout)

        # Encoder (Attn - Conv - Norm)
        num_layers = configs["e_layers"]

        # NOTE(review): the arguments below are positional, so their meaning
        # depends on the signatures in model.Informer.attn, which are not
        # visible here. If those classes follow the upstream Informer2020
        # signatures, the third positional of the attention class is `scale`
        # (not dropout) and the fourth positional of AttentionLayer is
        # `d_keys` (not `mix`) -- confirm against the project's attn module.
        attn_layers = []
        for _ in range(num_layers):
            inner_attention = attention_cls(False, configs["factor"], dropout, False)
            attention = AttentionLayer(inner_attention, d_model, configs["n_heads"], False)
            attn_layers.append(
                EncoderLayer(
                    attention,
                    d_model,
                    configs["d_ff"],
                    dropout,
                    configs["activation"],
                )
            )

        # Conv layers are only built when distillation is enabled.
        if configs.get("distil"):
            conv_layers = [ConvLayer(d_model) for _ in range(num_layers - 1)]
        else:
            conv_layers = None

        final_norm = nn.LayerNorm(d_model)
        self.encoder = Encoder(attn_layers, conv_layers, norm_layer=final_norm)

        # Forecast head
        self.head = TemporalProjectionHead(
            d_model=d_model,
            pred_len=self.pred_len,
            c_out=configs["c_out"],
        )

    def forward(self, x_enc):
        """Embed ``x_enc``, encode it, project it, and return the tail.

        Args:
            x_enc: Input passed straight to the embedding layer
                (presumably (batch, seq_len, enc_in) -- TODO confirm).

        Returns:
            The head output restricted to its last ``pred_len`` entries along
            dimension 1. Note that with ``pred_len == 0`` the slice ``-0:``
            keeps the whole dimension.
        """
        embedded = self.embedding(x_enc)
        # The encoder returns a pair; the second element is discarded here.
        encoded, _ = self.encoder(embedded)
        projected = self.head(encoded)
        return projected[:, -self.pred_len:, :]