# model/Informer/head.py
import torch
import torch.nn as nn


class TemporalProjectionHead(nn.Module):
    """Map encoder features ``[B, L, D]`` to a forecast ``[B, pred_len, C]``.

    The sequence axis is collapsed by averaging first, so the head only sees a
    single ``[B, D]`` summary per sample. That summary is expanded to
    ``pred_len`` steps by a ``Linear(1, pred_len)`` layer: at output step
    ``t`` every feature ``d`` becomes ``weight[t] * mean[d] + bias[t]``. The
    feature axis is then mapped from ``d_model`` to ``c_out``.

    Args:
        d_model: Size of the last (feature) axis of the input.
        pred_len: Number of output time steps.
        c_out: Size of the last axis of the output.
    """

    def __init__(self, d_model, pred_len, c_out):
        super().__init__()
        # Submodule attribute names become state_dict keys; keep them (and
        # their creation order) stable so saved checkpoints still load.
        self.temporal_proj = nn.Linear(1, pred_len)    # 1 pooled step -> pred_len steps
        self.channel_proj = nn.Linear(d_model, c_out)  # d_model features -> c_out

    def forward(self, x):
        """Return the forecast for ``x`` (``[B, L, D]``) as ``[B, pred_len, C]``."""
        summary = x.mean(dim=1).unsqueeze(-1)            # [B, D, 1]
        steps = self.temporal_proj(summary)              # [B, D, pred_len]
        return self.channel_proj(steps.transpose(1, 2))  # [B, pred_len, C]