TrafficWheel/model/EXP/EXP3_easy.py

28 lines
952 B
Python

import torch
import torch.nn as nn
import torch.nn.functional as F
class EXP(nn.Module):
    """Per-series MLP forecaster applied independently to every (batch, node) pair.

    Only channel 0 of the input's last axis is used. Each length-``T`` history
    is encoded by ``Linear -> ReLU -> Dropout`` and decoded by a single linear
    layer into ``horizon * output_dim`` values. The same weights are shared
    across all nodes and batch items.

    Args:
        args: Mapping-like config supporting ``[]`` and ``.get``. Keys read:
            ``'horizon'`` (required): number of output steps.
            ``'output_dim'`` (required): number of output channels per step.
            ``'hidden_dim'`` (optional, default 128): encoder width.
            ``'input_len'`` (optional, default 12): number of input time
            steps ``T`` the encoder expects.
    """

    def __init__(self, args):
        super().__init__()
        self.horizon = args['horizon']
        self.output_dim = args['output_dim']
        hidden_dim = args.get('hidden_dim', 128)
        # The input window used to be hard-coded to 12; it is now configurable,
        # with 12 kept as the default so existing configs behave identically.
        self.input_len = args.get('input_len', 12)

        self.encoder = nn.Sequential(
            nn.Linear(in_features=self.input_len, out_features=hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.decoder = nn.Linear(hidden_dim, self.horizon * self.output_dim)

    def forward(self, x):
        """Forecast from a 4-D input.

        Args:
            x: Tensor of shape ``(B, T, N, C)``; ``T`` must equal the
                configured input length (12 by default). Only ``x[..., 0]``
                is used.

        Returns:
            Tensor of shape ``(B, horizon, N, output_dim)``.
        """
        x_main = x[..., 0]  # (B, T, N)
        B, T, N = x_main.shape

        # Put time last and fold nodes into the batch axis so the linear
        # layers see one length-T series per row.
        h_in = x_main.permute(0, 2, 1).reshape(B * N, T)

        h = self.encoder(h_in)  # (B*N, hidden_dim)
        out_flat = self.decoder(h)  # (B*N, horizon * output_dim)

        # Unfold back to (B, N, horizon, output_dim), then move the horizon
        # axis ahead of the node axis.
        out = out_flat.view(B, N, self.horizon, self.output_dim)
        return out.permute(0, 2, 1, 3)