import torch
import torch.nn as nn
import torch.nn.functional as F

from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.masking import TriangularCausalMask, ProbMask


class Informer(nn.Module):
    def __init__(self, configs):
        super(Informer, self).__init__()
        # Extract hyperparameters from the configs dict, falling back to defaults
        self.enc_in = configs.get("enc_in", 1)
        self.dec_in = configs.get("dec_in", 1)
        self.c_out = configs.get("c_out", 1)
        self.seq_len = configs.get("seq_len", 96)
        self.label_len = configs.get("label_len", 48)
        self.out_len = configs.get("out_len", 24)
        self.factor = configs.get("factor", 5)
        self.d_model = configs.get("d_model", 512)
        self.n_heads = configs.get("n_heads", 8)
        self.e_layers = configs.get("e_layers", 3)
        self.d_layers = configs.get("d_layers", 2)
        self.d_ff = configs.get("d_ff", 512)
        self.dropout = configs.get("dropout", 0.0)
        self.attn = configs.get("attn", "prob")
        self.embed = configs.get("embed", "fixed")
        self.freq = configs.get("freq", "h")
        self.activation = configs.get("activation", "gelu")
        self.output_attention = configs.get("output_attention", False)
        self.distil = configs.get("distil", True)
        self.mix = configs.get("mix", True)
        self.device = configs.get("device", torch.device('cuda:0'))

        self.pred_len = self.out_len

        # Embedding layers for encoder and decoder inputs
        self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
        self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)

        # Attention type: ProbSparse attention ('prob') or canonical full attention
        Attn = ProbAttention if self.attn == 'prob' else FullAttention

        # Encoder: stacked EncoderLayers, optionally interleaved with ConvLayers for distilling
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        Attn(False, self.factor, attention_dropout=self.dropout, output_attention=self.output_attention),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation
                ) for _ in range(self.e_layers)
            ],
            [
                ConvLayer(
                    self.d_model
                ) for _ in range(self.e_layers - 1)
            ] if self.distil else None,
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Decoder: masked (causal) self-attention followed by cross-attention over the encoder output
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=self.mix),
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation,
                )
                for _ in range(self.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Projection layer: map d_model features back to c_out output channels
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # If x_dec / x_mark_dec are not provided, build them from the last label_len
        # steps of the encoder inputs plus pred_len zero placeholders.
        if x_dec is None:
            zeros = torch.zeros(x_enc.shape[0], self.pred_len, x_enc.shape[2],
                                dtype=x_enc.dtype, device=x_enc.device)
            x_dec = torch.cat([x_enc[:, -self.label_len:, :], zeros], dim=1)
        if x_mark_dec is None and x_mark_enc is not None:
            mark_zeros = torch.zeros(x_mark_enc.shape[0], self.pred_len, x_mark_enc.shape[2],
                                     dtype=x_mark_enc.dtype, device=x_mark_enc.device)
            x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], mark_zeros], dim=1)

        # Encode
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decode and project to the output dimension
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, pred_len, c_out]


class InformerStack(nn.Module):
    def __init__(self, configs):
        super(InformerStack, self).__init__()
        # Extract hyperparameters from the configs dict, falling back to defaults
        self.enc_in = configs.get("enc_in", 1)
        self.dec_in = configs.get("dec_in", 1)
        self.c_out = configs.get("c_out", 1)
        self.seq_len = configs.get("seq_len", 96)
        self.label_len = configs.get("label_len", 48)
        self.out_len = configs.get("out_len", 24)
        self.factor = configs.get("factor", 5)
        self.d_model = configs.get("d_model", 512)
        self.n_heads = configs.get("n_heads", 8)
        self.e_layers = configs.get("e_layers", [3, 2, 1])  # one entry per encoder in the stack
        self.d_layers = configs.get("d_layers", 2)
        self.d_ff = configs.get("d_ff", 512)
        self.dropout = configs.get("dropout", 0.0)
        self.attn = configs.get("attn", "prob")
        self.embed = configs.get("embed", "fixed")
        self.freq = configs.get("freq", "h")
        self.activation = configs.get("activation", "gelu")
        self.output_attention = configs.get("output_attention", False)
        self.distil = configs.get("distil", True)
        self.mix = configs.get("mix", True)
        self.device = configs.get("device", torch.device('cuda:0'))

        self.pred_len = self.out_len

        # Embedding layers for encoder and decoder inputs
        self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
        self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)

        # Attention type: ProbSparse attention ('prob') or canonical full attention
        Attn = ProbAttention if self.attn == 'prob' else FullAttention

        # Encoder stack: one Encoder per entry in e_layers, combined by EncoderStack
        inp_lens = list(range(len(self.e_layers)))  # [0, 1, 2, ...]; customize here if needed
        encoders = [
            Encoder(
                [
                    EncoderLayer(
                        AttentionLayer(
                            Attn(False, self.factor, attention_dropout=self.dropout, output_attention=self.output_attention),
                            self.d_model, self.n_heads, mix=False),
                        self.d_model,
                        self.d_ff,
                        dropout=self.dropout,
                        activation=self.activation
                    ) for _ in range(el)
                ],
                [
                    ConvLayer(
                        self.d_model
                    ) for _ in range(el - 1)
                ] if self.distil else None,
                norm_layer=torch.nn.LayerNorm(self.d_model)
            ) for el in self.e_layers
        ]
        self.encoder = EncoderStack(encoders, inp_lens)

        # Decoder: masked (causal) self-attention followed by cross-attention over the encoder output
        self.decoder = Decoder(
            [
                DecoderLayer(
                    AttentionLayer(
                        Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=self.mix),
                    AttentionLayer(
                        FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
                        self.d_model, self.n_heads, mix=False),
                    self.d_model,
                    self.d_ff,
                    dropout=self.dropout,
                    activation=self.activation,
                )
                for _ in range(self.d_layers)
            ],
            norm_layer=torch.nn.LayerNorm(self.d_model)
        )

        # Projection layer: map d_model features back to c_out output channels
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

    def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        # If x_dec / x_mark_dec are not provided, build them from the last label_len
        # steps of the encoder inputs plus pred_len zero placeholders.
        if x_dec is None:
            zeros = torch.zeros(x_enc.shape[0], self.pred_len, x_enc.shape[2],
                                dtype=x_enc.dtype, device=x_enc.device)
            x_dec = torch.cat([x_enc[:, -self.label_len:, :], zeros], dim=1)
        if x_mark_dec is None and x_mark_enc is not None:
            mark_zeros = torch.zeros(x_mark_enc.shape[0], self.pred_len, x_mark_enc.shape[2],
                                     dtype=x_mark_enc.dtype, device=x_mark_enc.device)
            x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], mark_zeros], dim=1)

        # Encode
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        # Decode and project to the output dimension
        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        if self.output_attention:
            return dec_out[:, -self.pred_len:, :], attns
        else:
            return dec_out[:, -self.pred_len:, :]  # [B, pred_len, c_out]
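

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): builds an Informer
    # with example hyperparameters and runs one forward pass on random data.
    # Assumptions: the model.Informer package is importable, and DataEmbedding with
    # embed="fixed", freq="h" expects 4 time-feature columns in x_mark (month, day,
    # weekday, hour), as in the reference Informer implementation; zero-valued marks
    # are used here only as placeholders.
    configs = {
        "enc_in": 7, "dec_in": 7, "c_out": 7,
        "seq_len": 96, "label_len": 48, "out_len": 24,
        "device": torch.device("cpu"),
    }
    model = Informer(configs)

    batch_size, n_time_features = 2, 4
    x_enc = torch.randn(batch_size, configs["seq_len"], configs["enc_in"])
    x_mark_enc = torch.zeros(batch_size, configs["seq_len"], n_time_features)

    # x_dec / x_mark_dec are omitted; forward() builds them from the encoder inputs.
    y_hat = model(x_enc, x_mark_enc)
    print(y_hat.shape)  # expected: torch.Size([2, 24, 7])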