"""Informer model definitions (encoder-decoder for long-sequence forecasting)."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from model.Informer.encoder import Encoder, EncoderLayer, ConvLayer, EncoderStack
from model.Informer.decoder import Decoder, DecoderLayer
from model.Informer.attn import FullAttention, ProbAttention, AttentionLayer
from model.Informer.embed import DataEmbedding
class Informer(nn.Module):
    """Informer encoder-decoder for long-sequence time-series forecasting.

    All hyperparameters are read from the ``args`` dict (``d_model``,
    ``n_heads``, ``e_layers``, ``pred_len``, ...). The forward pass returns
    the last ``pred_len`` time steps of the projected decoder output.
    """

    def __init__(self, args):
        super(Informer, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        d_model = args['d_model']
        n_heads = args['n_heads']
        dropout = args['dropout']

        # Token + temporal embeddings for the encoder and decoder inputs.
        self.enc_embedding = DataEmbedding(args['enc_in'], d_model, args['embed'], args['freq'], dropout)
        self.dec_embedding = DataEmbedding(args['dec_in'], d_model, args['embed'], args['freq'], dropout)

        # Attention flavour: ProbSparse ('prob') or canonical full attention.
        attn_cls = ProbAttention if args['attn'] == 'prob' else FullAttention

        def build_encoder_layer():
            # Non-causal self-attention inside each encoder layer.
            inner = attn_cls(False, args['factor'],
                             attention_dropout=dropout,
                             output_attention=args['output_attention'])
            return EncoderLayer(
                AttentionLayer(inner, d_model, n_heads, mix=False),
                d_model,
                args['d_ff'],
                dropout=dropout,
                activation=args['activation'],
            )

        # Optional distilling conv layers squeeze the sequence between
        # encoder layers (one fewer conv than attention layers).
        conv_layers = None
        if args['distil']:
            conv_layers = [ConvLayer(d_model) for _ in range(args['e_layers'] - 1)]

        self.encoder = Encoder(
            [build_encoder_layer() for _ in range(args['e_layers'])],
            conv_layers,
            norm_layer=torch.nn.LayerNorm(d_model),
        )

        def build_decoder_layer():
            # Causal self-attention over the decoder stream plus full
            # cross-attention onto the encoder output.
            self_attn = attn_cls(True, args['factor'],
                                 attention_dropout=dropout, output_attention=False)
            cross_attn = FullAttention(False, args['factor'],
                                       attention_dropout=dropout, output_attention=False)
            return DecoderLayer(
                AttentionLayer(self_attn, d_model, n_heads, mix=args['mix']),
                AttentionLayer(cross_attn, d_model, n_heads, mix=False),
                d_model,
                args['d_ff'],
                dropout=dropout,
                activation=args['activation'],
            )

        self.decoder = Decoder(
            [build_decoder_layer() for _ in range(args['d_layers'])],
            norm_layer=torch.nn.LayerNorm(d_model),
        )

        # Map decoder features back to the requested output channel count.
        self.projection = nn.Linear(d_model, args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Encode, decode, project, and slice off the prediction horizon.

        Returns ``(predictions, attns)`` when ``output_attention`` is set,
        otherwise just ``predictions`` of shape [B, pred_len, c_out].
        """
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out,
                               x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        predictions = dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.output_attention:
            return predictions, attns
        return predictions
class InformerStack(nn.Module):
    """Stacked-encoder variant of Informer.

    ``args['e_layers']`` is a list; one sub-encoder is built per entry,
    with that many attention layers each, and the sub-encoders are combined
    by ``EncoderStack``. Decoder and forward pass mirror ``Informer``.
    """

    def __init__(self, args):
        super(InformerStack, self).__init__()
        self.pred_len = args['pred_len']
        self.attn = args['attn']
        self.output_attention = args['output_attention']

        d_model = args['d_model']
        n_heads = args['n_heads']
        dropout = args['dropout']

        # Token + temporal embeddings for the encoder and decoder inputs.
        self.enc_embedding = DataEmbedding(args['enc_in'], d_model, args['embed'], args['freq'], dropout)
        self.dec_embedding = DataEmbedding(args['dec_in'], d_model, args['embed'], args['freq'], dropout)

        # Attention flavour: ProbSparse ('prob') or canonical full attention.
        attn_cls = ProbAttention if args['attn'] == 'prob' else FullAttention

        def build_encoder_layer():
            inner = attn_cls(False, args['factor'],
                             attention_dropout=dropout,
                             output_attention=args['output_attention'])
            return EncoderLayer(
                AttentionLayer(inner, d_model, n_heads, mix=False),
                d_model,
                args['d_ff'],
                dropout=dropout,
                activation=args['activation'],
            )

        # Default input-length indices [0, 1, 2, ...]; customizable upstream.
        inp_lens = list(range(len(args['e_layers'])))

        encoders = []
        for depth in args['e_layers']:
            attn_layers = [build_encoder_layer() for _ in range(depth)]
            conv_layers = [ConvLayer(d_model) for _ in range(depth - 1)] if args['distil'] else None
            encoders.append(Encoder(attn_layers, conv_layers,
                                    norm_layer=torch.nn.LayerNorm(d_model)))
        self.encoder = EncoderStack(encoders, inp_lens)

        def build_decoder_layer():
            # Causal self-attention plus full cross-attention on encoder output.
            self_attn = attn_cls(True, args['factor'],
                                 attention_dropout=dropout, output_attention=False)
            cross_attn = FullAttention(False, args['factor'],
                                       attention_dropout=dropout, output_attention=False)
            return DecoderLayer(
                AttentionLayer(self_attn, d_model, n_heads, mix=args['mix']),
                AttentionLayer(cross_attn, d_model, n_heads, mix=False),
                d_model,
                args['d_ff'],
                dropout=dropout,
                activation=args['activation'],
            )

        self.decoder = Decoder(
            [build_decoder_layer() for _ in range(args['d_layers'])],
            norm_layer=torch.nn.LayerNorm(d_model),
        )

        # Map decoder features back to the requested output channel count.
        self.projection = nn.Linear(d_model, args['c_out'], bias=True)

    def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec,
                enc_self_mask=None, dec_self_mask=None, dec_enc_mask=None):
        """Encode, decode, project, and slice off the prediction horizon.

        Returns ``(predictions, attns)`` when ``output_attention`` is set,
        otherwise just ``predictions`` of shape [B, pred_len, c_out].
        """
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=enc_self_mask)

        dec_out = self.dec_embedding(x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out,
                               x_mask=dec_self_mask, cross_mask=dec_enc_mask)
        dec_out = self.projection(dec_out)

        predictions = dec_out[:, -self.pred_len:, :]  # [B, L, D]
        if self.output_attention:
            return predictions, attns
        return predictions