# TrafficWheel/model/Informer/head.py
#
# 26 lines
# 790 B
# Python

# model/Informer/head.py
import torch
import torch.nn as nn
class TemporalProjectionHead(nn.Module):
    """Project pooled sequence features onto a fixed prediction horizon.

    The input is averaged over its sequence axis, the resulting single step
    is expanded to ``pred_len`` steps by a linear layer, and the feature axis
    is then mapped from ``d_model`` to ``c_out``.

    Shapes: ``[B, L, D] -> [B, pred_len, C]``.

    Because the sequence is mean-pooled before the temporal projection, the
    output depends on the input only through its mean over the sequence
    axis; the order of the time steps does not matter.

    Args:
        d_model: Size of the input feature axis (``D``).
        pred_len: Number of time steps in the output.
        c_out: Number of output channels (``C``).
    """

    def __init__(self, d_model: int, pred_len: int, c_out: int) -> None:
        super().__init__()
        # in_features is 1 because forward() reduces the sequence to a single
        # (mean) step before this layer is applied.
        self.temporal_proj = nn.Linear(1, pred_len)
        self.channel_proj = nn.Linear(d_model, c_out)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the projected tensor of shape ``[B, pred_len, C]``.

        Args:
            x: Input of shape ``[B, L, D]`` where ``D`` equals the ``d_model``
                given at construction.

        Returns:
            Tensor of shape ``[B, pred_len, C]``.
        """
        # Collapse the sequence axis to its mean, keeping the axis so it can
        # be expanded again by the temporal projection.
        x = x.mean(dim=1, keepdim=True)  # [B, 1, D]
        # nn.Linear acts on the last axis, so move the length-1 time axis there.
        x = x.transpose(1, 2)  # [B, D, 1]
        x = self.temporal_proj(x)  # [B, D, pred_len]
        # Restore the time-major layout before projecting the feature axis.
        x = x.transpose(1, 2)  # [B, pred_len, D]
        x = self.channel_proj(x)  # [B, pred_len, C]
        return x