TrafficWheel/model/AEPSA/Chebyshev+Laplacian_constru...

122 lines
4.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
import torch.nn as nn
import torch.nn.functional as F
class ChebSpectralReconstructor(nn.Module):
    """
    Spectral reconstruction using Chebyshev polynomial approximation.

    For each time step t the node features X_t are filtered with the
    Chebyshev basis T_0(L_tilde) X_t, ..., T_K(L_tilde) X_t of the rescaled
    normalized Laplacian L_tilde = 2 L / lambda_max - I. Term k is projected
    by its own weight W_k (in_dim -> hidden_dim); the K+1 projections are
    mixed by learnable coefficients alpha_k and passed through tanh to give
    a one-step reconstruction of X_{t+1}.

    Input:
        X: Tensor of shape [N, T, d] (node features over time), d == in_dim
        A: Adjacency matrix [N, N]; expected symmetric, because the largest
           Laplacian eigenvalue is obtained with ``torch.linalg.eigvalsh``,
           which only reads one triangle of its input.
    Output:
        X_tilde: Tensor [N, T, hidden_dim*(K+2)] for downstream model
            (K+1 spectral projections plus the reconstruction, each
            hidden_dim wide).

    Note:
        ``fit`` compares the hidden_dim-wide reconstruction against the
        d-wide next frame, so training requires hidden_dim == d (the
        defaults in_dim=16, hidden_dim=32 do not satisfy this).
    """

    def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'):
        super().__init__()
        self.K = K
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device
        # Spectral coefficients α_k (learnable), one per Chebyshev order.
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
        # Propagation weights W_k (learnable): in_dim -> hidden_dim per order.
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])
        # Move to the target device *before* building the optimizer so the
        # optimizer is guaranteed to track the on-device parameters.
        self.to(device)
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

    def compute_normalized_laplacian(self, A):
        """Compute normalized Laplacian L = I - D^{-1/2} A D^{-1/2}.

        Degrees are the row sums of A. Nodes with zero degree get a scaling
        factor of 0 (instead of inf), so their rows/columns of
        D^{-1/2} A D^{-1/2} vanish and L has 1 on their diagonal.
        """
        deg = A.sum(dim=1)
        d_inv_sqrt = deg.pow(-0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0
        # Scale rows and columns with broadcasting instead of two dense
        # matmuls against a diagonal matrix.
        norm_adj = d_inv_sqrt.unsqueeze(1) * A * d_inv_sqrt.unsqueeze(0)
        I = torch.eye(A.size(0), device=A.device, dtype=norm_adj.dtype)
        return I - norm_adj

    def compute_scaled_laplacian(self, L, eps=1e-6):
        """Rescale L to L_tilde = (2 / lambda_max) L - I.

        This maps the spectrum of L into [-1, 1], the domain of the
        Chebyshev polynomials. L is assumed symmetric (``eigvalsh`` reads
        only one triangle). If the largest eigenvalue is (numerically) zero
        (e.g. a graph whose only edges are self-loops, giving L = 0), the
        normalized-Laplacian upper bound 2 is used instead of dividing by
        zero.
        """
        lambda_max = torch.linalg.eigvalsh(L).max()
        if lambda_max.item() <= eps:
            lambda_max = torch.full_like(lambda_max, 2.0)
        I = torch.eye(L.size(0), device=L.device, dtype=L.dtype)
        return (2.0 / lambda_max) * L - I

    def chebyshev_polynomials(self, L_tilde, X_t):
        """Compute [T_0(L_tilde)X_t, ..., T_K(L_tilde)X_t] recursively.

        Uses T_0 X = X, T_1 X = L_tilde X and
        T_k X = 2 L_tilde T_{k-1} X - T_{k-2} X, so only matrix-vector
        style products with X_t are needed (no matrix powers of L_tilde).
        """
        T_k_list = [X_t]
        if self.K >= 1:
            T_k_list.append(L_tilde @ X_t)
        for k in range(2, self.K + 1):
            T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2])
        return T_k_list

    def forward_one_step(self, X_t, L, L_tilde=None):
        """Compute one-step reconstruction of X_{t+1}.

        Args:
            X_t: node features at step t, [N, in_dim].
            L: normalized Laplacian [N, N].
            L_tilde: optional precomputed ``compute_scaled_laplacian(L)``.
                Pass it when calling repeatedly with the same L to avoid an
                eigendecomposition per call.

        Returns:
            (X_hat_next, H_list): the tanh reconstruction [N, hidden_dim]
            and the K+1 per-order projections, each [N, hidden_dim].
        """
        if L_tilde is None:
            L_tilde = self.compute_scaled_laplacian(L)
        T_k_list = self.chebyshev_polynomials(L_tilde, X_t)
        H_list = [T_k @ W_k for T_k, W_k in zip(T_k_list, self.W)]
        H_stack = torch.stack(H_list, dim=-1)  # [N, hidden_dim, K+1]
        # Weighted sum over Chebyshev orders with the learnable alpha_k.
        H_weighted = (H_stack * self.alpha.view(1, 1, -1)).sum(dim=-1)
        X_hat_next = torch.tanh(H_weighted)
        return X_hat_next, H_list

    def forward_sequence(self, X, A):
        """
        Compute decoupled multi-scale representation over all time steps.

        X: [N, T, d], A: [N, N]

        Row t+1 (t >= 0) is cat(H_0, ..., H_K, X_hat_{t+1}) computed from
        X_t, i.e. hidden_dim*(K+2) wide. Row 0 has no predecessor, so it
        holds the observed X_1 right-aligned in a zero row of that width:
        when d == hidden_dim, X_1 occupies the same slot as the
        reconstructions X_hat in later rows.

        Returns:
            Tensor [N, T, hidden_dim*(K+2)].

        Raises:
            ValueError: if d > hidden_dim*(K+2), so X_1 cannot be embedded.
        """
        X = X.to(self.device)
        A = A.to(self.device)
        L = self.compute_normalized_laplacian(A)
        # L is the same for every step: rescale (eigendecompose) only once.
        L_tilde = self.compute_scaled_laplacian(L)
        N, T, d = X.shape
        width = self.hidden_dim * (self.K + 2)
        if T == 0:
            return X.new_zeros(N, 0, width)
        pad = width - d
        if pad < 0:
            raise ValueError(
                f"feature dim d={d} exceeds output width "
                f"hidden_dim*(K+2)={width}; cannot embed X_1"
            )
        # Include X_1, padded on the left so every row has the same width.
        multi_scale_list = [F.pad(X[:, 0, :], (pad, 0))]
        for t in range(T - 1):
            X_hat_next, H_list = self.forward_one_step(X[:, t, :], L, L_tilde=L_tilde)
            multi_scale_list.append(torch.cat(H_list + [X_hat_next], dim=-1))
        # Shape: [N, T, hidden_dim*(K+2)]
        return torch.stack(multi_scale_list, dim=1)

    def loss_fn(self, X_true, X_pred):
        """Mean-squared error between the true and predicted next frame."""
        return F.mse_loss(X_true, X_pred)

    def fit(self, X, A, num_epochs=200, verbose=True):
        """Optimization process to learn α_k and W_k.

        Takes one optimizer step per transition t -> t+1 of X. The printed
        loss is the sum of the per-transition losses over the epoch.
        Requires X.size(-1) == hidden_dim, since the reconstruction is
        compared directly with the next frame.
        """
        X = X.to(self.device)
        A = A.to(self.device)
        L = self.compute_normalized_laplacian(A)
        # Loop-invariant: compute the rescaled Laplacian once, not per step.
        L_tilde = self.compute_scaled_laplacian(L)
        for epoch in range(num_epochs):
            total_loss = 0
            for t in range(X.size(1) - 1):
                X_t = X[:, t, :]
                X_true_next = X[:, t + 1, :]
                X_pred_next, _ = self.forward_one_step(X_t, L, L_tilde=L_tilde)
                loss = self.loss_fn(X_true_next, X_pred_next)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}")
        print("Training complete.")
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    # Toy problem: N nodes, T time steps, d features per node.
    num_nodes, num_steps, feat_dim = 20, 10, 16

    # Random dense weighted graph, symmetrized so the Laplacian is symmetric.
    adjacency = torch.rand(num_nodes, num_nodes)
    adjacency = 0.5 * (adjacency + adjacency.T)

    features = torch.randn(num_nodes, num_steps, feat_dim)

    model = ChebSpectralReconstructor(
        K=3, in_dim=feat_dim, hidden_dim=16, lr=1e-3, device='cpu'
    )
    model.fit(features, adjacency, num_epochs=50)

    X_tilde = model.forward_sequence(features, adjacency)
    print("Final embedding shape:", X_tilde.shape)  # [N, T, hidden_dim*(K+2)]