import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding


class Informer(nn.Module):
    """Informer encoder-decoder for long-sequence forecasting; all hyper-parameters are read from the `args` dict."""

    def __init__(self, args):
        super(Informer, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        # Encoding
        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        # Attention
        Attn = ProbAttention if args['attn'] == 'prob' else FullAttention
        # Encoder
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(False, args['factor'], attention_dropout=args['dropout'],
                             output_attention=args['output_attention']),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation']
                ) for _ in range(args['e_layers'])
            ],
            # Distilling conv layers between encoder layers, used only when 'distil' is set
            [
                ConvLayer(
                    args['d_model']
                ) for _ in range(args['e_layers'] - 1)
            ] if args['distil'] else None,
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    # Masked (causal) self-attention over the decoder input
                    AttentionLayer(
                        Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False),
                        args['d_model'], args['n_heads'], mix=args['mix']),
                    # Full cross-attention over the encoder output
                    AttentionLayer(
                        FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation'],
                )
                for _ in range(args['d_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )
        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # Embed and encode the input sequence
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Embed the decoder input, decode against the encoder output, and project to c_out
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        # Return only the last pred_len steps of the decoded sequence
        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]


class InformerStack(nn.Module):
    """Stacked-encoder variant of Informer; `args['e_layers']` is a list with one encoder depth per stack element."""

    def __init__(self, args):
        super(InformerStack, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        # Encoding
        self.enc_embedding = DataEmbedding(args['enc_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        self.dec_embedding = DataEmbedding(args['dec_in'], args['d_model'], args['embed'], args['freq'], args['dropout'])
        # Attention
        Attn = ProbAttention if args['attn'] == 'prob' else FullAttention
        # Encoder

        inp_lens = list(range(len(args['e_layers'])))  # [0, 1, 2, ...] you can customize here
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            Attn(False, args['factor'], attention_dropout=args['dropout'],
                                 output_attention=args['output_attention']),
                            args['d_model'], args['n_heads'], mix=False),
                        args['d_model'],
                        args['d_ff'],
                        dropout=args['dropout'],
                        activation=args['activation']
                    ) for _ in range(el)
                ],
                [
                    ConvLayer(
                        args['d_model']
                    ) for _ in range(el - 1)
                ] if args['distil'] else None,
                norm_layer=torch.nn.LayerNorm(args['d_model'])
            ) for el in args['e_layers']
        ]
        self.encoder = EncoderStack(encoders, inp_lens)
        # Decoder
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, args['factor'], attention_dropout=args['dropout'], output_attention=False),
                        args['d_model'], args['n_heads'], mix=args['mix']),
                    AttentionLayer(
                        FullAttention(False, args['factor'], attention_dropout=args['dropout'], output_attention=False),
                        args['d_model'], args['n_heads'], mix=False),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation'],
                )
                for _ in range(args['d_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )
        self.projection = nn.Linear(args['d_model'], args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # Embed and encode the input sequence
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Embed the decoder input, decode against the encoder output, and project to c_out
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        # Return only the last pred_len steps of the decoded sequence
        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, L, D]
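

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the original module). It assumes the
# local model.Informer.* sub-modules follow the upstream Informer reference
# implementation, in particular that DataEmbedding with embed='timeF' and
# freq='h' expects 4 time-feature columns in the x_mark tensors. All key
# values below (sequence lengths, model sizes) are illustrative only;
# 'seq_len' and 'label_len' are not read by the constructor and are used
# here purely to size the dummy tensors.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    args = {
        'seq_len': 96, 'label_len': 48, 'pred_len': 24,
        'enc_in': 7, 'dec_in': 7, 'c_out': 7,
        'd_model': 512, 'n_heads': 8, 'd_ff': 2048,
        'e_layers': 2, 'd_layers': 1,
        'attn': 'prob', 'factor': 5, 'distil': True, 'mix': True,
        'embed': 'timeF', 'freq': 'h', 'dropout': 0.05,
        'activation': 'gelu', 'output_attention': False,
    }
    model = Informer(args)

    batch = 2
    x_enc = torch.randn(batch, args['seq_len'], args['enc_in'])
    x_mark_enc = torch.randn(batch, args['seq_len'], 4)  # 4 time features for freq='h'
    x_dec = torch.randn(batch, args['label_len'] + args['pred_len'], args['dec_in'])
    x_mark_dec = torch.randn(batch, args['label_len'] + args['pred_len'], 4)

    out = model(x_enc, x_mark_enc, x_dec, x_mark_dec)
    print(out.shape)  # expected: torch.Size([2, 24, 7])

    # InformerStack takes the same args, except that 'e_layers' must be a
    # list of encoder depths (one entry per encoder in the stack), e.g.:
    # stack_args = {**args, 'e_layers': [3, 2, 1]}
    # stack = InformerStack(stack_args)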