# TrafficWheel/model/Informer/model.py
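"""Informer and InformerStack forecasting models for TrafficWheel.

Both models follow the Informer architecture (Zhou et al., AAAI 2021) and are
configured through a plain dict of hyperparameters; see each class's
``__init__`` for the supported keys and their defaults.
"""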

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
from model.Informer.masking import TriangularCausalMask, ProbMask


class Informer(nn.Module):
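    """Informer encoder-decoder for long-sequence time-series forecasting.

    Hyperparameters are read from a plain ``configs`` dict (see the
    ``configs.get`` calls below for supported keys and defaults). By default
    the encoder uses ProbSparse self-attention (``attn="prob"``) with
    convolutional distilling between layers (``distil=True``); the decoder
    combines masked self-attention with encoder-decoder cross attention and
    projects the result to ``c_out`` channels.
    """
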
def __init__(self, configs):
super(Informer, self).__init__()
        # Extract hyperparameters from the configs dict, falling back to defaults
self.enc_in = configs.get("enc_in", 1)
self.dec_in = configs.get("dec_in", 1)
self.c_out = configs.get("c_out", 1)
self.seq_len = configs.get("seq_len", 96)
self.label_len = configs.get("label_len", 48)
self.out_len = configs.get("out_len", 24)
self.factor = configs.get("factor", 5)
self.d_model = configs.get("d_model", 512)
self.n_heads = configs.get("n_heads", 8)
self.e_layers = configs.get("e_layers", 3)
self.d_layers = configs.get("d_layers", 2)
self.d_ff = configs.get("d_ff", 512)
self.dropout = configs.get("dropout", 0.0)
self.attn = configs.get("attn", "prob")
self.embed = configs.get("embed", "fixed")
self.freq = configs.get("freq", "h")
self.activation = configs.get("activation", "gelu")
self.output_attention = configs.get("output_attention", False)
self.distil = configs.get("distil", True)
self.mix = configs.get("mix", True)
self.device = configs.get("device", torch.device('cuda:0'))
self.pred_len = self.out_len
        # Embedding layers for the encoder and decoder inputs
self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)
        # Attention type: ProbSparse attention ('prob') or canonical full attention
Attn = ProbAttention if self.attn == 'prob' else FullAttention
        # Encoder
self.encoder = Encoder(
[
EncoderLayer(
AttentionLayer(Attn(False, self.factor, attention_dropout=self.dropout, output_attention=self.output_attention),
self.d_model, self.n_heads, mix=False),
self.d_model,
self.d_ff,
dropout=self.dropout,
activation=self.activation
) for l in range(self.e_layers)
],
[
ConvLayer(
self.d_model
) for l in range(self.e_layers - 1)
] if self.distil else None,
norm_layer=torch.nn.LayerNorm(self.d_model)
)
        # Decoder
self.decoder = Decoder(
[
DecoderLayer(
AttentionLayer(Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
self.d_model, self.n_heads, mix=self.mix),
AttentionLayer(FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
self.d_model, self.n_heads, mix=False),
self.d_model,
self.d_ff,
dropout=self.dropout,
activation=self.activation,
)
for l in range(self.d_layers)
],
norm_layer=torch.nn.LayerNorm(self.d_model)
)
        # Projection layer mapping d_model features to c_out output channels
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
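        """Run the encoder-decoder and return the forecast.

        Args:
            x_enc: encoder input, shape [B, seq_len, enc_in].
            x_mark_enc: optional time-feature marks for the encoder input.
            x_dec, x_mark_dec: decoder input and marks; if omitted they are
                built from the last label_len steps of x_enc (and x_mark_enc)
                followed by pred_len zero placeholders.
            enc_self_mask, dec_self_mask, dec_enc_mask: optional attention masks.

        Returns:
            Forecast of shape [B, pred_len, c_out]; the encoder attention maps
            are also returned when output_attention is True.
        """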
        # If x_dec / x_mark_dec are not provided, build the decoder input from
        # the last label_len encoder steps followed by pred_len zero placeholders
if x_dec is None:
x_dec = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_enc[:, :self.pred_len, :])], dim=1)
if x_mark_dec is None and x_mark_enc is not None:
x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], torch.zeros_like(x_mark_enc[:, :self.pred_len, :])], dim=1)
        # Encode
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
        # Decode against the encoder memory, then project to c_out channels
dec_out = self.dec_embedding(x_dec, x_mark_dec)
dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
dec_out = self.projection(dec_out)
if self.output_attention:
return dec_out[:, -self.pred_len:, :], attns
else:
            return dec_out[:, -self.pred_len:, :]  # [B, pred_len, c_out]


class InformerStack(nn.Module):
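    """Informer variant that replaces the single encoder with an EncoderStack.

    ``e_layers`` is a list (e.g. ``[3, 2, 1]``); one encoder is built per
    entry, each with that many layers, and their outputs are combined by
    ``EncoderStack``. The decoder and the rest of the interface match
    ``Informer``.
    """
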
def __init__(self, configs):
super(InformerStack, self).__init__()
        # Extract hyperparameters from the configs dict, falling back to defaults
self.enc_in = configs.get("enc_in", 1)
self.dec_in = configs.get("dec_in", 1)
self.c_out = configs.get("c_out", 1)
self.seq_len = configs.get("seq_len", 96)
self.label_len = configs.get("label_len", 48)
self.out_len = configs.get("out_len", 24)
self.factor = configs.get("factor", 5)
self.d_model = configs.get("d_model", 512)
self.n_heads = configs.get("n_heads", 8)
self.e_layers = configs.get("e_layers", [3, 2, 1])
self.d_layers = configs.get("d_layers", 2)
self.d_ff = configs.get("d_ff", 512)
self.dropout = configs.get("dropout", 0.0)
self.attn = configs.get("attn", "prob")
self.embed = configs.get("embed", "fixed")
self.freq = configs.get("freq", "h")
self.activation = configs.get("activation", "gelu")
self.output_attention = configs.get("output_attention", False)
self.distil = configs.get("distil", True)
self.mix = configs.get("mix", True)
self.device = configs.get("device", torch.device('cuda:0'))
self.pred_len = self.out_len
        # Embedding layers for the encoder and decoder inputs
self.enc_embedding = DataEmbedding(self.enc_in, self.d_model, self.embed, self.freq, self.dropout)
self.dec_embedding = DataEmbedding(self.dec_in, self.d_model, self.embed, self.freq, self.dropout)
        # Attention type: ProbSparse attention ('prob') or canonical full attention
Attn = ProbAttention if self.attn == 'prob' else FullAttention
        # Encoder stack: one encoder per entry in e_layers, each with its own depth
inp_lens = list(range(len(self.e_layers))) # [0,1,2,...] you can customize here
encoders = [
Encoder(
[
EncoderLayer(
AttentionLayer(Attn(False, self.factor, attention_dropout=self.dropout, output_attention=self.output_attention),
self.d_model, self.n_heads, mix=False),
self.d_model,
self.d_ff,
dropout=self.dropout,
activation=self.activation
) for l in range(el)
],
[
ConvLayer(
self.d_model
) for l in range(el-1)
] if self.distil else None,
norm_layer=torch.nn.LayerNorm(self.d_model)
) for el in self.e_layers]
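        # As in the reference Informer implementation, EncoderStack feeds each
        # encoder a progressively shorter suffix of the input (controlled by
        # inp_lens) and concatenates the encoder outputs along the time dimension.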
self.encoder = EncoderStack(encoders, inp_lens)
        # Decoder
self.decoder = Decoder(
[
DecoderLayer(
AttentionLayer(Attn(True, self.factor, attention_dropout=self.dropout, output_attention=False),
self.d_model, self.n_heads, mix=self.mix),
AttentionLayer(FullAttention(False, self.factor, attention_dropout=self.dropout, output_attention=False),
self.d_model, self.n_heads, mix=False),
self.d_model,
self.d_ff,
dropout=self.dropout,
activation=self.activation,
)
for l in range(self.d_layers)
],
norm_layer=torch.nn.LayerNorm(self.d_model)
)
        # Projection layer mapping d_model features to c_out output channels
        self.projection = nn.Linear(self.d_model, self.c_out, bias=True)

def forward(self, x_enc, x_mark_enc=None, x_dec=None, x_mark_dec=None,
enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
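        """Same interface and output shape as ``Informer.forward``: returns a
        forecast of shape [B, pred_len, c_out] (plus the encoder attention maps
        when output_attention is True)."""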
        # If x_dec / x_mark_dec are not provided, build the decoder input from
        # the last label_len encoder steps followed by pred_len zero placeholders
if x_dec is None:
x_dec = torch.cat([x_enc[:, -self.label_len:, :], torch.zeros_like(x_enc[:, :self.pred_len, :])], dim=1)
if x_mark_dec is None and x_mark_enc is not None:
x_mark_dec = torch.cat([x_mark_enc[:, -self.label_len:, :], torch.zeros_like(x_mark_enc[:, :self.pred_len, :])], dim=1)
        # Encode
enc_out = self.enc_embedding(x_enc, x_mark_enc)
enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)
        # Decode against the encoder memory, then project to c_out channels
dec_out = self.dec_embedding(x_dec, x_mark_dec)
dec_out = self.decoder(dec_out, enc_out, x_mask=dec_self_mask, cross_mask=dec_enc_mask)
dec_out = self.projection(dec_out)
if self.output_attention:
return dec_out[:, -self.pred_len:, :], attns
else:
            return dec_out[:, -self.pred_len:, :]  # [B, pred_len, c_out]
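

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original module): build an
    # Informer from a configs dict and run a forward pass on random data.
    # Assumption: DataEmbedding accepts x_mark=None, since only the encoder
    # series is passed here; if your embed module requires time-feature marks,
    # supply x_mark_enc / x_mark_dec explicitly.
    configs = {
        "enc_in": 1, "dec_in": 1, "c_out": 1,
        "seq_len": 96, "label_len": 48, "out_len": 24,
        "d_model": 64, "n_heads": 4, "e_layers": 2, "d_layers": 1,
        "d_ff": 128, "device": torch.device("cpu"),
    }
    model = Informer(configs)
    x_enc = torch.randn(8, configs["seq_len"], configs["enc_in"])
    y_hat = model(x_enc)   # x_dec is built internally from x_enc
    print(y_hat.shape)     # expected: torch.Size([8, 24, 1])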