28 lines
952 B
Python
28 lines
952 B
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class EXP(nn.Module):
    """Per-node MLP forecaster.

    For every node independently (weights shared across nodes), the node's
    input window from channel 0 is encoded by ``Linear -> ReLU -> Dropout``
    and decoded by a single ``Linear`` into ``horizon * output_dim`` values.

    Args:
        args: Mapping-like configuration (accessed with ``[]`` and ``.get``).
            Required keys: ``'horizon'`` and ``'output_dim'``.
            Optional keys:
              * ``'hidden_dim'`` -- encoder width (default 128).
              * ``'input_len'`` -- number of input time steps ``T`` the
                encoder expects (default 12, the previously hard-coded value).
              * ``'dropout'`` -- dropout probability after the hidden
                activation (default 0.1, the previously hard-coded value).
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        hidden_dim = args.get('hidden_dim', 128)
        # The encoder consumes a node's whole input window in one Linear layer,
        # so its width must equal the time dimension T seen in forward().
        self.input_len = args.get('input_len', 12)
        dropout = args.get('dropout', 0.1)

        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.input_len, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every node.

        Args:
            x: 4-D tensor ``(B, T, N, C)``; only channel 0 is used.
                ``T`` must equal the configured ``input_len``.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.

        Raises:
            ValueError: if ``x`` is not 4-D, or its time dimension does not
                match the configured ``input_len``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, T, N, C), got shape {tuple(x.shape)}"
            )
        x_main = x[..., 0]  # (B, T, N): keep only the first channel
        B, T, N = x_main.shape
        if T != self.input_len:
            raise ValueError(
                f"input has {T} time steps but the model was built for "
                f"input_len={self.input_len}"
            )

        # Treat each (batch, node) pair as an independent sample: (B*N, T).
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)
        h = self.encoder(h_in)  # (B*N, hidden_dim)
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)

        # Unflatten, then move horizon ahead of the node axis:
        # (B, N, horizon, output_dim) -> (B, horizon, N, output_dim).
        out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out
|