TrafficWheel/model/iTransformer/iTransformer.py


import torch
import torch.nn as nn
from model.iTransformer.layers.Transformer_EncDec import Encoder, EncoderLayer
from model.iTransformer.layers.SelfAttn import FullAttention, AttentionLayer
from model.iTransformer.layers.Embed import DataEmbedding_inverted


class iTransformer(nn.Module):
    """
    Inverted Transformer (iTransformer) for multivariate time-series forecasting.
    Paper link: https://arxiv.org/abs/2310.06625
    """

    def __init__(self, args):
        super(iTransformer, self).__init__()
        self.pred_len = args['pred_len']

        # Inverted embedding: each variate's full input series (length seq_len)
        # becomes a single token of dimension d_model.
        self.enc_embedding = DataEmbedding_inverted(args['seq_len'], args['d_model'], args['dropout'])

        # Encoder-only architecture: self-attention operates across variate tokens.
        self.encoder = Encoder(
            [
                EncoderLayer(
                    AttentionLayer(
                        FullAttention(False, attention_dropout=args['dropout'],
                                      output_attention=args['output_attention']),
                        args['d_model'], args['n_heads']),
                    args['d_model'],
                    args['d_ff'],
                    dropout=args['dropout'],
                    activation=args['activation']
                ) for _ in range(args['e_layers'])
            ],
            norm_layer=torch.nn.LayerNorm(args['d_model'])
        )

        # Project each variate token from d_model back to the forecast horizon.
        self.projector = nn.Linear(args['d_model'], args['pred_len'], bias=True)

    def forecast(self, x_enc, x_mark_enc):
        _, _, N = x_enc.shape  # x_enc: [B, seq_len, N]

        # Embed: [B, seq_len, N] -> [B, N (+ time-mark tokens), d_model]
        enc_out = self.enc_embedding(x_enc, x_mark_enc)
        enc_out, attns = self.encoder(enc_out, attn_mask=None)

        # Project and restore the time axis: [B, N_tokens, pred_len] -> [B, pred_len, N_tokens],
        # then keep only the first N variates (filter out the covariate / time-mark tokens).
        dec_out = self.projector(enc_out).permute(0, 2, 1)[:, :, :N]
        return dec_out, attns

    def forward(self, x_enc, x_mark_enc=None):
        dec_out, attns = self.forecast(x_enc, x_mark_enc)
        return dec_out[:, -self.pred_len:, :]  # [B, pred_len, N]
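

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It instantiates the
# model with an illustrative hyperparameter dict and runs a dummy forward pass.
# The numeric values below are assumptions, not the repo's actual config, and
# this assumes the repo's DataEmbedding_inverted accepts x_mark_enc=None.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    example_args = {
        'seq_len': 96,             # input window length (assumed)
        'pred_len': 24,            # forecast horizon (assumed)
        'd_model': 128,            # token embedding size (assumed)
        'n_heads': 8,              # attention heads (assumed)
        'e_layers': 2,             # number of encoder layers (assumed)
        'd_ff': 256,               # feed-forward width (assumed)
        'dropout': 0.1,
        'activation': 'gelu',
        'output_attention': False,
    }
    model = iTransformer(example_args)

    batch_size, num_variates = 4, 7
    x_enc = torch.randn(batch_size, example_args['seq_len'], num_variates)

    # Time features (x_mark_enc) are optional; omit them here.
    y_hat = model(x_enc)
    print(y_hat.shape)  # expected: torch.Size([4, 24, 7])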