修复没permute的bug

This commit is contained in:
czzhangheng 2025-11-12 16:40:09 +08:00
parent 38431fb4c1
commit 095e8c60dc
1 changed files with 121 additions and 0 deletions

View File

@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
class ChebSpectralReconstructor(nn.Module):
"""
Spectral reconstruction using Chebyshev polynomial approximation.
Input:
X: Tensor of shape [N, T, d] (node features over time)
A: Adjacency matrix [N, N]
Output:
X_tilde: Tensor [N, T, d*(K+2)] for downstream model
"""
def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'):
super().__init__()
self.K = K
self.in_dim = in_dim
self.hidden_dim = hidden_dim
self.device = device
# Spectral coefficients α_k (learnable)
self.alpha = nn.Parameter(torch.randn(K + 1, 1))
# Propagation weights W_k (learnable)
self.W = nn.ParameterList([
nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
])
self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
self.to(device)
def compute_normalized_laplacian(self, A):
"""Compute normalized Laplacian L = I - D^{-1/2} A D^{-1/2}"""
I = torch.eye(A.size(0), device=A.device)
D = torch.diag(torch.sum(A, dim=1))
D_inv_sqrt = torch.pow(D, -0.5)
D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0.0
L = I - D_inv_sqrt @ A @ D_inv_sqrt
return L
def chebyshev_polynomials(self, L_tilde, X_t):
"""Compute [T_0(L_tilde)X_t, ..., T_K(L_tilde)X_t] recursively"""
T_k_list = [X_t]
if self.K >= 1:
T_k_list.append(L_tilde @ X_t)
for k in range(2, self.K + 1):
T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2])
return T_k_list
def forward_one_step(self, X_t, L):
"""Compute one-step reconstruction of X_{t+1}"""
lambda_max = torch.linalg.eigvalsh(L).max().real
L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
T_k_list = self.chebyshev_polynomials(L_tilde, X_t)
H_list = []
for k in range(self.K + 1):
H_k = T_k_list[k] @ self.W[k]
H_list.append(H_k)
H_stack = torch.stack(H_list, dim=-1) # [N, hidden_dim, K+1]
H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1)
X_hat_next = torch.tanh(H_weighted)
return X_hat_next, H_list
def forward_sequence(self, X, A):
"""
Compute decoupled multi-scale representation over all time steps.
X: [N, T, d], A: [N, N]
"""
L = self.compute_normalized_laplacian(A)
N, T, d = X.shape
multi_scale_list = [X[:, 0, :]] # include X_1
for t in range(T - 1):
X_hat_next, H_list = self.forward_one_step(X[:, t, :], L)
X_tilde_next = torch.cat(H_list + [X_hat_next], dim=-1)
multi_scale_list.append(X_tilde_next)
X_tilde = torch.stack(multi_scale_list, dim=1)
# Shape: [N, T, hidden_dim*(K+2)]
return X_tilde
def loss_fn(self, X_true, X_pred):
return F.mse_loss(X_true, X_pred)
def fit(self, X, A, num_epochs=200, verbose=True):
"""Optimization process to learn α_k and W_k"""
L = self.compute_normalized_laplacian(A)
for epoch in range(num_epochs):
total_loss = 0
for t in range(X.size(1) - 1):
X_t = X[:, t, :]
X_true_next = X[:, t + 1, :]
X_pred_next, _ = self.forward_one_step(X_t, L)
loss = self.loss_fn(X_true_next, X_pred_next)
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
total_loss += loss.item()
if verbose and (epoch + 1) % 10 == 0:
print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}")
print("Training complete.")
# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
N, T, d = 20, 10, 16
A = torch.rand(N, N)
A = (A + A.T) / 2 # symmetrize
X = torch.randn(N, T, d)
model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu')
model.fit(X, A, num_epochs=50)
X_tilde = model.forward_sequence(X, A)
print("Final embedding shape:", X_tilde.shape) # [N, T, hidden_dim*(K+2)]