From 095e8c60dc5f5371ac71fb439bd586666f49f3b4 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Wed, 12 Nov 2025 16:40:09 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E6=B2=A1permute=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../AEPSA/Chebyshev+Laplacian_construction.py | 121 ++++++++++++++++++ 1 file changed, 121 insertions(+) create mode 100644 model/AEPSA/Chebyshev+Laplacian_construction.py diff --git a/model/AEPSA/Chebyshev+Laplacian_construction.py b/model/AEPSA/Chebyshev+Laplacian_construction.py new file mode 100644 index 0000000..a697c67 --- /dev/null +++ b/model/AEPSA/Chebyshev+Laplacian_construction.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class ChebSpectralReconstructor(nn.Module): + """ + Spectral reconstruction using Chebyshev polynomial approximation. + Input: + X: Tensor of shape [N, T, d] (node features over time) + A: Adjacency matrix [N, N] + Output: + X_tilde: Tensor [N, T, d*(K+2)] for downstream model + """ + def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'): + super().__init__() + self.K = K + self.in_dim = in_dim + self.hidden_dim = hidden_dim + self.device = device + + # Spectral coefficients α_k (learnable) + self.alpha = nn.Parameter(torch.randn(K + 1, 1)) + + # Propagation weights W_k (learnable) + self.W = nn.ParameterList([ + nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1) + ]) + + self.optimizer = torch.optim.Adam(self.parameters(), lr=lr) + self.to(device) + + def compute_normalized_laplacian(self, A): + """Compute normalized Laplacian L = I - D^{-1/2} A D^{-1/2}""" + I = torch.eye(A.size(0), device=A.device) + D = torch.diag(torch.sum(A, dim=1)) + D_inv_sqrt = torch.pow(D, -0.5) + D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0.0 + L = I - D_inv_sqrt @ A @ D_inv_sqrt + return L + + def chebyshev_polynomials(self, L_tilde, X_t): + """Compute [T_0(L_tilde)X_t, ..., T_K(L_tilde)X_t] recursively""" + T_k_list = [X_t] + if self.K >= 1: + T_k_list.append(L_tilde @ X_t) + for k in range(2, self.K + 1): + T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2]) + return T_k_list + + def forward_one_step(self, X_t, L): + """Compute one-step reconstruction of X_{t+1}""" + lambda_max = torch.linalg.eigvalsh(L).max().real + L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device) + T_k_list = self.chebyshev_polynomials(L_tilde, X_t) + + H_list = [] + for k in range(self.K + 1): + H_k = T_k_list[k] @ self.W[k] + H_list.append(H_k) + H_stack = torch.stack(H_list, dim=-1) # [N, hidden_dim, K+1] + H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1) + X_hat_next = torch.tanh(H_weighted) + return X_hat_next, H_list + + def forward_sequence(self, X, A): + """ + Compute decoupled multi-scale representation over all time steps. + X: [N, T, d], A: [N, N] + """ + L = self.compute_normalized_laplacian(A) + N, T, d = X.shape + multi_scale_list = [X[:, 0, :]] # include X_1 + + for t in range(T - 1): + X_hat_next, H_list = self.forward_one_step(X[:, t, :], L) + X_tilde_next = torch.cat(H_list + [X_hat_next], dim=-1) + multi_scale_list.append(X_tilde_next) + + X_tilde = torch.stack(multi_scale_list, dim=1) + # Shape: [N, T, hidden_dim*(K+2)] + return X_tilde + + def loss_fn(self, X_true, X_pred): + return F.mse_loss(X_true, X_pred) + + def fit(self, X, A, num_epochs=200, verbose=True): + """Optimization process to learn α_k and W_k""" + L = self.compute_normalized_laplacian(A) + for epoch in range(num_epochs): + total_loss = 0 + for t in range(X.size(1) - 1): + X_t = X[:, t, :] + X_true_next = X[:, t + 1, :] + X_pred_next, _ = self.forward_one_step(X_t, L) + loss = self.loss_fn(X_true_next, X_pred_next) + + self.optimizer.zero_grad() + loss.backward() + self.optimizer.step() + total_loss += loss.item() + + if verbose and (epoch + 1) % 10 == 0: + print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}") + + print("Training complete.") + +# --------------------------- +# Example usage +# --------------------------- +if __name__ == "__main__": + N, T, d = 20, 10, 16 + A = torch.rand(N, N) + A = (A + A.T) / 2 # symmetrize + X = torch.randn(N, T, d) + + model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu') + model.fit(X, A, num_epochs=50) + + X_tilde = model.forward_sequence(X, A) + print("Final embedding shape:", X_tilde.shape) # [N, T, hidden_dim*(K+2)] +