import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.masking import TriangularCausalMask, ProbMask


class Informer(nn.Module):
    def __init__(self, configs):
        super(Informer, self).__init__()
        # Extract hyperparameters from configs
        self.enc_in = configs.get("enc_in", 1)
        self.dec_in = configs.get("dec_in", 1)
        self.c_out = configs.get("c_out", 1)
        self.seq_len = configs.get("seq_len", 96)
        self.label_len = configs.get("label_len", 48)
        self.out_len = configs.get("out_len", 24)
        self.factor = configs.get("factor", 5)
        self.d_model = configs.get("d_model", 512)
        self.n_heads = configs.get("n_heads", 8)
        self.e_layers = configs.get("e_layers", 3)
        self.d_layers = configs.get("d_layers", 2)
        self.d_ff = configs.get("d_ff", 512)
        self.dropout = configs.get("dropout", 0.0)
        self.attn = configs.get("attn", "prob")
        self.embed = configs.get("embed", "fixed")
        self.freq = configs.get("freq", "h")
        self.activation = configs.get("activation", "gelu")
        self.output_attention = configs.get("output_attention", False)
        self.distil = configs.get("distil", True)
        self.mix = configs.get("mix", True)
        self.device = configs.get("device", torch.device('cuda:0'))
        self.pred_len = self.out_len

        # Input embedding layers for encoder and decoder
        self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
        self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)

        # Attention type: ProbSparse attention or canonical full attention
        Attn = ProbAttention if self.attn == 'prob' else FullAttention

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(False, self.factor, attention_dropout=self.dropout,
                             output_attention=self.output_attention),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation
                ) for l in range(self.e_layers)
            ],
            [
                ConvLayer(
                    self.d_model
                ) for l in range(self.e_layers - 1)
            ] if self.distil else None,
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=self.mix),
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation,
                ) for l in range(self.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Projection layer back to the output dimension
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # If x_dec / x_mark_dec are not provided, build them from x_enc and label_len:
        # reuse the last label_len encoder steps and append pred_len zero placeholders.
        if x_dec is None:
            x_dec = torch.cat([x_enc[:, -self.label_len:, :],
                               torch.zeros_like(x_enc[:, :self.pred_len, :])], dim=1)
        if x_mark_dec is None and x_mark_enc is not None:
            x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :],
                                    torch.zeros_like(x_mark_enc[:, :self.pred_len, :])], dim=1)

        # Encode
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decode
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]


class InformerStack(nn.Module):
    def __init__(self, configs):
        super(InformerStack, self).__init__()
        # Extract hyperparameters from configs
        self.enc_in = configs.get("enc_in", 1)
        self.dec_in = configs.get("dec_in", 1)
        self.c_out = configs.get("c_out", 1)
        self.seq_len = configs.get("seq_len", 96)
        self.label_len = configs.get("label_len", 48)
        self.out_len = configs.get("out_len", 24)
        self.factor = configs.get("factor", 5)
        self.d_model = configs.get("d_model", 512)
        self.n_heads = configs.get("n_heads", 8)
        self.e_layers = configs.get("e_layers", [3, 2, 1])
        self.d_layers = configs.get("d_layers", 2)
        self.d_ff = configs.get("d_ff", 512)
        self.dropout = configs.get("dropout", 0.0)
        self.attn = configs.get("attn", "prob")
        self.embed = configs.get("embed", "fixed")
        self.freq = configs.get("freq", "h")
        self.activation = configs.get("activation", "gelu")
        self.output_attention = configs.get("output_attention", False)
        self.distil = configs.get("distil", True)
        self.mix = configs.get("mix", True)
        self.device = configs.get("device", torch.device('cuda:0'))
        self.pred_len = self.out_len

        # Input embedding layers for encoder and decoder
        self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
        self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)

        # Attention type: ProbSparse attention or canonical full attention
        Attn = ProbAttention if self.attn == 'prob' else FullAttention

        # Encoder stack: one encoder per entry in e_layers, fed progressively shorter inputs
        inp_lens = list(range(len(self.e_layers)))  # [0, 1, 2, ...] you can customize here
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            Attn(False, self.factor, attention_dropout=self.dropout,
                                 output_attention=self.output_attention),
                            self.d_model, self.n_heads, mix=False),
                        self.d_model,
                        self.d_ff,
                        dropout=self.dropout,
                        activation=self.activation
                    ) for l in range(el)
                ],
                [
                    ConvLayer(
                        self.d_model
                    ) for l in range(el - 1)
                ] if self.distil else None,
                norm_layer=torch.nn.LayerNorm(self.d_model)
            ) for el in self.e_layers]
        self.encoder = EncoderStack(encoders, inp_lens)

        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=self.mix),
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation,
                ) for l in range(self.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Projection layer back to the output dimension
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # If x_dec / x_mark_dec are not provided, build them from x_enc and label_len:
        # reuse the last label_len encoder steps and append pred_len zero placeholders.
        if x_dec is None:
            x_dec = torch.cat([x_enc[:, -self.label_len:, :],
                               torch.zeros_like(x_enc[:, :self.pred_len, :])], dim=1)
        if x_mark_dec is None and x_mark_enc is not None:
            x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :],
                                    torch.zeros_like(x_mark_enc[:, :self.pred_len, :])], dim=1)

        # Encode
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decode
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
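# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative; not part of the original module).
# Assumptions: the model.Informer package is importable, DataEmbedding matches
# the reference Informer implementation (so with embed='fixed' and freq='h' the
# time-mark tensor carries four calendar-index columns [month, day, weekday,
# hour]), and CPU execution is acceptable. Batch size and feature counts below
# are arbitrary choices for a smoke test, not values mandated by this module.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    configs = {
        "enc_in": 7, "dec_in": 7, "c_out": 7,
        "seq_len": 96, "label_len": 48, "out_len": 24,
        "device": torch.device("cpu"),
    }
    model = Informer(configs).eval()

    batch_size = 2
    seq_len = configs["seq_len"]
    x_enc = torch.randn(batch_size, seq_len, configs["enc_in"])
    # Calendar-index features in the ranges the fixed temporal embedding expects.
    x_mark_enc = torch.stack([
        torch.randint(1, 13, (batch_size, seq_len)),   # month
        torch.randint(1, 32, (batch_size, seq_len)),   # day of month
        torch.randint(0, 7, (batch_size, seq_len)),    # weekday
        torch.randint(0, 24, (batch_size, seq_len)),   # hour
    ], dim=-1).float()

    # x_dec / x_mark_dec are omitted: forward() rebuilds them from the last
    # label_len encoder steps plus zero placeholders for the forecast horizon.
    with torch.no_grad():
        y_hat = model(x_enc, x_mark_enc)
    print(y_hat.shape)  # expected: torch.Size([2, 24, 7])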