import torch
import torch.nn as nn
import torch.nn.functional as F


class EXP(nn.Module):
    """Shared per-series MLP forecaster.

    Each of the N series in the input (channel 0 only) is handled as an
    independent sample. A shared encoder maps its ``input_len`` past values to
    a hidden vector. A linear decoder then emits all ``horizon * output_dim``
    forecast values at once.

    Args:
        args: Configuration mapping.
            Required keys:
            ``'horizon'``: number of future steps to predict.
            ``'output_dim'``: number of output channels per step.
            Optional keys:
            ``'hidden_dim'`` (default 128): encoder width.
            ``'input_len'`` (default 12): window length T expected in
            ``forward``; must equal ``x.shape[1]``.
            ``'dropout'`` (default 0.1): dropout probability after the encoder.
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        # The default of 12 keeps existing configs working. Set 'input_len'
        # to use a different window length.
        self.input_len = args.get('input_len', 12)
        hidden_dim = args.get('hidden_dim', 128)
        dropout = args.get('dropout', 0.1)

        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.input_len, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast ``horizon`` steps for every series.

        Args:
            x: 4-D tensor of shape (B, T, N, C) with ``T == self.input_len``.
                Only channel ``x[..., 0]`` is used. The axes are presumably
                batch / time / series / feature; confirm against the data loader.

        Returns:
            Tensor of shape (B, horizon, N, output_dim).

        Raises:
            ValueError: if ``x`` is not 4-D or its window length differs from
                ``self.input_len``.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (B, T, N, C), got shape {tuple(x.shape)}"
            )
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape
        if T != self.input_len:
            raise ValueError(
                f"input window length {T} does not match configured "
                f"input_len {self.input_len}"
            )

        # Fold series into the batch so the encoder sees one length-T vector per series.
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)
        h = self.encoder(h_in)  # (B*N, hidden_dim)
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)
        # out_flat is freshly produced by Linear (contiguous), so view is safe here.
        out = out_flat.view(B, N, self.horizon, self.output_dim).permute(0, 2, 1, 3)
        return out