import torch
import torch.nn as nn
import torch.nn.functional as F


class ChebSpectralReconstructor(nn.Module):
    """
    Spectral reconstruction using a Chebyshev polynomial approximation.

    Input:
        X: Tensor of shape [N, T, d] (node features over time)
        A: Adjacency matrix [N, N]
    Output:
        X_tilde: Tensor [N, T, hidden_dim * (K + 2)] for a downstream model

    Note: the one-step reconstruction loss in `fit` compares the predicted
    hidden state against the raw next-step features, so `hidden_dim` must
    equal `in_dim` whenever `fit` is used.
    """

    def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'):
        super().__init__()
        self.K = K
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device

        # Spectral coefficients alpha_k (learnable), one per Chebyshev order.
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))

        # Propagation weights W_k (learnable), one per Chebyshev order.
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim))
            for _ in range(K + 1)
        ])

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def compute_normalized_laplacian(self, A):
        """Compute the symmetric normalized Laplacian L = I - D^{-1/2} A D^{-1/2}."""
        I = torch.eye(A.size(0), device=A.device)
        deg = A.sum(dim=1)
        d_inv_sqrt = deg.pow(-0.5)
        # Isolated nodes have degree 0; zero out the resulting infinities.
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
        D_inv_sqrt = torch.diag(d_inv_sqrt)
        return I - D_inv_sqrt @ A @ D_inv_sqrt

    def scale_laplacian(self, L):
        """Rescale L so its spectrum lies in [-1, 1]: L_tilde = (2 / lambda_max) L - I."""
        lambda_max = torch.linalg.eigvalsh(L).max()
        return (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)

    def chebyshev_polynomials(self, L_tilde, X_t):
        """Compute [T_0(L_tilde) X_t, ..., T_K(L_tilde) X_t] via the recurrence
        T_k = 2 L_tilde T_{k-1} - T_{k-2}, with T_0 = X_t and T_1 = L_tilde X_t."""
        T_k_list = [X_t]
        if self.K >= 1:
            T_k_list.append(L_tilde @ X_t)
        for _ in range(2, self.K + 1):
            T_k_list.append(2 * (L_tilde @ T_k_list[-1]) - T_k_list[-2])
        return T_k_list

    def forward_one_step(self, X_t, L_tilde):
        """One-step reconstruction: predict X_{t+1} from X_t on the scaled Laplacian."""
        T_k_list = self.chebyshev_polynomials(L_tilde, X_t)
        # Per-order hidden states H_k = T_k(L_tilde) X_t W_k, each [N, hidden_dim].
        H_list = [T_k_list[k] @ self.W[k] for k in range(self.K + 1)]
        H_stack = torch.stack(H_list, dim=-1)  # [N, hidden_dim, K+1]
        # Weighted sum over orders using the learnable spectral coefficients.
        H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1)
        X_hat_next = torch.tanh(H_weighted)
        return X_hat_next, H_list

    def forward_sequence(self, X, A):
        """
        Compute the decoupled multi-scale representation over all time steps.
        X: [N, T, d], A: [N, N]
        Returns X_tilde: [N, T, hidden_dim * (K + 2)]
        """
        L_tilde = self.scale_laplacian(self.compute_normalized_laplacian(A))
        N, T, d = X.shape
        out_dim = self.hidden_dim * (self.K + 2)
        # Include X_1, zero-padded on the right so every time step has the
        # same width (requires out_dim >= d, which holds for any K >= 0
        # when hidden_dim >= d).
        multi_scale_list = [F.pad(X[:, 0, :], (0, out_dim - d))]
        for t in range(T - 1):
            X_hat_next, H_list = self.forward_one_step(X[:, t, :], L_tilde)
            # K+1 per-order states plus the reconstruction: (K+2) blocks.
            multi_scale_list.append(torch.cat(H_list + [X_hat_next], dim=-1))
        return torch.stack(multi_scale_list, dim=1)  # [N, T, hidden_dim*(K+2)]

    def loss_fn(self, X_true, X_pred):
        return F.mse_loss(X_pred, X_true)

    def fit(self, X, A, num_epochs=200, verbose=True):
        """Optimization loop that learns alpha_k and W_k via one-step reconstruction."""
        assert self.hidden_dim == self.in_dim, \
            "fit() reconstructs raw features, so hidden_dim must equal in_dim"
        # The graph is static, so the scaled Laplacian (and its costly
        # eigendecomposition) is computed once, not at every step.
        L_tilde = self.scale_laplacian(self.compute_normalized_laplacian(A))
        for epoch in range(num_epochs):
            total_loss = 0.0
            for t in range(X.size(1) - 1):
                X_t = X[:, t, :]
                X_true_next = X[:, t + 1, :]
                X_pred_next, _ = self.forward_one_step(X_t, L_tilde)
                loss = self.loss_fn(X_true_next, X_pred_next)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch [{epoch + 1}/{num_epochs}] | Loss: {total_loss:.6f}")
        print("Training complete.")


# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    N, T, d = 20, 10, 16
    A = torch.rand(N, N)
    A = (A + A.T) / 2  # symmetrize
    X = torch.randn(N, T, d)

    model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu')
    model.fit(X, A, num_epochs=50)

    X_tilde = model.forward_sequence(X, A)
    print("Final embedding shape:", X_tilde.shape)  # [N, T, hidden_dim*(K+2)]
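

# ---------------------------
# Optional sanity check (a minimal sketch, not part of the model above).
# The symmetric normalized Laplacian has eigenvalues in [0, 2], which is
# why scale_laplacian maps the spectrum into [-1, 1], the interval on
# which the Chebyshev basis is bounded.  The helper name `_check_spectrum`
# is introduced here purely for illustration.
# ---------------------------
def _check_spectrum():
    torch.manual_seed(0)
    A = torch.rand(8, 8)
    A = (A + A.T) / 2  # symmetrize, as in the example above
    m = ChebSpectralReconstructor(K=3, in_dim=4, hidden_dim=4)
    L = m.compute_normalized_laplacian(A)
    evals = torch.linalg.eigvalsh(m.scale_laplacian(L))
    # After rescaling, all eigenvalues should sit in [-1, 1] (up to tolerance).
    assert evals.min() >= -1 - 1e-5 and evals.max() <= 1 + 1e-5


if __name__ == "__main__":
    _check_spectrum()
    print("Spectrum check passed.")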
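

# ---------------------------
# A second optional check (again a sketch for illustration): the recursive
# builder can be cross-checked against the closed form T_k(x) = cos(k * arccos(x))
# by feeding it a 1x1 "Laplacian" whose single entry plays the role of
# x in [-1, 1].  `_check_recurrence` is likewise a hypothetical helper.
# ---------------------------
def _check_recurrence():
    m = ChebSpectralReconstructor(K=4, in_dim=1, hidden_dim=1)
    x = torch.tensor([[0.3]])  # scalar stand-in for L_tilde
    ones = torch.ones(1, 1)    # X_t = 1, so T_k(L_tilde) @ X_t reduces to T_k(x)
    T_list = m.chebyshev_polynomials(x, ones)
    for k, T_k in enumerate(T_list):
        expected = torch.cos(k * torch.arccos(x))
        assert torch.allclose(T_k, expected, atol=1e-5), (k, T_k, expected)


if __name__ == "__main__":
    _check_recurrence()
    print("Chebyshev recurrence check passed.")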