26 lines
790 B
Python
26 lines
790 B
Python
# model/Informer/head.py
|
|
import torch
|
|
import torch.nn as nn
|
|
|
|
|
|
class TemporalProjectionHead(nn.Module):
    """Forecasting head mapping encoder features to a prediction window.

    Shapes: ``[B, L, D] -> [B, pred_len, C]``.

    Two modes are supported:

    * ``seq_len=None`` (default, original behaviour): the input is averaged
      over the sequence axis, and the pooled ``[B, D]`` features are expanded
      to ``pred_len`` steps by ``nn.Linear(1, pred_len)``. The mean discards
      ordering along ``L``, so every forecast step is an affine function of
      the same pooled vector.
    * ``seq_len`` given: a full temporal projection
      ``nn.Linear(seq_len, pred_len)`` is applied across the whole sequence
      axis, so each forecast step can weight every input position
      differently. ``L`` must then equal ``seq_len``.

    In both modes ``nn.Linear(d_model, c_out)`` then maps the feature axis
    to the output channels.

    Args:
        d_model: Feature size ``D`` of the incoming tensor.
        pred_len: Number of forecast steps to produce.
        c_out: Number of output channels ``C``.
        seq_len: Optional fixed input length ``L`` enabling the full
            temporal projection. Defaults to ``None`` (mean-pool mode), whose
            parameter names and shapes are unchanged from earlier versions of
            this head, so existing state dicts still load.
    """

    def __init__(self, d_model, pred_len, c_out, seq_len=None):
        super().__init__()
        self.seq_len = seq_len
        # In mean-pool mode the temporal axis collapses to length 1.
        in_len = 1 if seq_len is None else seq_len
        self.temporal_proj = nn.Linear(in_len, pred_len)
        self.channel_proj = nn.Linear(d_model, c_out)

    def forward(self, x):
        """Project ``x`` of shape ``[B, L, D]`` to ``[B, pred_len, C]``.

        Raises:
            ValueError: If ``seq_len`` was configured and ``x.size(1)``
                does not match it.
        """
        if self.seq_len is None:
            x = x.mean(dim=1, keepdim=True)  # [B, 1, D]
        elif x.size(1) != self.seq_len:
            raise ValueError(
                f"expected sequence length {self.seq_len}, got {x.size(1)}"
            )
        # nn.Linear acts on the last axis, so move time there and back.
        x = x.transpose(1, 2)  # [B, D, L'] where L' is 1 or seq_len
        x = self.temporal_proj(x)  # [B, D, pred_len]
        x = x.transpose(1, 2)  # [B, pred_len, D]
        return self.channel_proj(x)  # [B, pred_len, C]
|