import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding


class Informer(nn.Module):
    def __init__(self, args):
        super(Informer, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        # Encoding
        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'],
                                           args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'],
                                           args['freq'], args['dropout'])

        # Attention
        Attn = ProbAttention if args['attn'] == 'prob' else FullAttention

        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(False, args['factor'], attention_dropout=args['dropout'],
                             output_attention=args['output_attention']),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation']
                ) for l in range(args['e_layers'])
            ],
            [
                ConvLayer(args['d_model']) for l in range(args['e_layers'] - 1)
            ] if args['distil'] else None,
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )

        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, args['factor'], attention_dropout=args['dropout'],
                             output_attention=False),
                        args['d_model'], args['n_heads'], mix=args['mix']),
                    AttentionLayer(
                        FullAttention(False, args['factor'], attention_dropout=args['dropout'],
                                      output_attention=False),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation'],
                ) for l in range(args['d_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )

        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]


class InformerStack(nn.Module):
    def __init__(self, args):
        super(InformerStack, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        # Encoding
        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'],
                                           args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'],
                                           args['freq'], args['dropout'])

        # Attention
        Attn = ProbAttention if args['attn'] == 'prob' else FullAttention

        # Encoder
        inp_lens = list(range(len(args['e_layers'])))  # [0, 1, 2, ...] you can customize here
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            Attn(False, args['factor'], attention_dropout=args['dropout'],
                                 output_attention=args['output_attention']),
                            args['d_model'], args['n_heads'], mix=False),
                        args['d_model'],
                        args['d_ff'],
                        dropout=args['dropout'],
                        activation=args['activation']
                    ) for l in range(el)
                ],
                [
                    ConvLayer(args['d_model']) for l in range(el - 1)
                ] if args['distil'] else None,
                norm_layer=torch.nn.LayerNorm(args['d_model'])
            ) for el in args['e_layers']
        ]
        self.encoder = EncoderStack(encoders, inp_lens)

        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, args['factor'], attention_dropout=args['dropout'],
                             output_attention=False),
                        args['d_model'], args['n_heads'], mix=args['mix']),
                    AttentionLayer(
                        FullAttention(False, args['factor'], attention_dropout=args['dropout'],
                                      output_attention=False),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation'],
                ) for l in range(args['d_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )

        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
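

# Usage sketch (not part of the original module): a minimal example of building an
# Informer from an assumed hyperparameter dict and running one forward pass on random
# tensors. The concrete values below (seq_len=96, label_len=48, pred_len=24, 7 input
# features, 4-dimensional time-mark encodings) are illustrative assumptions rather than
# defaults of this repo; adjust them to your dataset and to the actual signatures of
# DataEmbedding / Encoder / Decoder under model/Informer/.
if __name__ == '__main__':
    args = {
        'enc_in': 7, 'dec_in': 7, 'c_out': 7,            # input/output feature dims
        'seq_len': 96, 'label_len': 48, 'pred_len': 24,  # window lengths
        'd_model': 512, 'n_heads': 8, 'e_layers': 2, 'd_layers': 1, 'd_ff': 2048,
        'factor': 5, 'dropout': 0.05, 'attn': 'prob', 'embed': 'timeF', 'freq': 'h',
        'activation': 'gelu', 'distil': True, 'mix': True, 'output_attention': False,
    }
    model = Informer(args)

    B = 4                # batch size
    time_feat_dim = 4    # assumed number of time features for freq='h'
    x_enc = torch.randn(B, args['seq_len'], args['enc_in'])
    x_mark_enc = torch.randn(B, args['seq_len'], time_feat_dim)
    # Decoder input: label_len known steps followed by pred_len placeholder steps.
    x_dec = torch.zeros(B, args['label_len'] + args['pred_len'], args['dec_in'])
    x_mark_dec = torch.randn(B, args['label_len'] + args['pred_len'], time_feat_dim)

    out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
    print(out.shape)  # expected: [B, pred_len, c_out]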