Fix bug caused by missing permute
This commit is contained in:
parent
38431fb4c1
commit
095e8c60dc
@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ChebSpectralReconstructor(nn.Module):
    """
    Spectral reconstruction using Chebyshev polynomial approximation.

    Input:
        X: Tensor of shape [N, T, d] (node features over time)
        A: Adjacency matrix [N, N]
    Output:
        X_tilde: Tensor [N, T, hidden_dim*(K+2)] for the downstream model
    """
    def __init__(self, K=3, in_dim=16, hidden_dim=16, lr=1e-3, device='cpu'):
        super().__init__()
        # The one-step reconstruction tanh(H_weighted) has hidden_dim features
        # and is compared to X_{t+1} (in_dim features) by MSE in fit(), so the
        # two widths must agree.
        assert hidden_dim == in_dim, "hidden_dim must equal in_dim for the reconstruction loss"
        self.K = K
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device

        # Spectral coefficients α_k (learnable)
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))

        # Propagation weights W_k (learnable)
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])

        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def compute_normalized_laplacian(self, A):
        """Compute normalized Laplacian L = I - D^{-1/2} A D^{-1/2}"""
        I = torch.eye(A.size(0), device=A.device)
        deg = torch.sum(A, dim=1)
        deg_inv_sqrt = torch.pow(deg, -0.5)
        deg_inv_sqrt[torch.isinf(deg_inv_sqrt)] = 0.0  # guard isolated nodes (degree 0)
        D_inv_sqrt = torch.diag(deg_inv_sqrt)
        L = I - D_inv_sqrt @ A @ D_inv_sqrt
        return L

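    # Illustrative note: for a symmetric, non-negatively weighted A, the
    # eigenvalues of L = I - D^{-1/2} A D^{-1/2} lie in [0, 2]. This is what
    # forward_one_step relies on when it rescales L so that the Chebyshev
    # recursion sees a spectrum inside [-1, 1]:
    #   L = self.compute_normalized_laplacian(A)
    #   torch.linalg.eigvalsh(L)  # values in [0, 2]
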
    def chebyshev_polynomials(self, L_tilde, X_t):
        """Compute [T_0(L_tilde)X_t, ..., T_K(L_tilde)X_t] recursively"""
        T_k_list = [X_t]  # T_0(L~) X = X
        if self.K >= 1:
            T_k_list.append(L_tilde @ X_t)  # T_1(L~) X = L~ X
        for k in range(2, self.K + 1):
            T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2])
        return T_k_list

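    # Worked example of the recursion (K = 3): with T_0(x) = 1, T_1(x) = x and
    # T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x), the returned list (each entry being
    # T_k(L~) X_t) is
    #   [X_t, L~ @ X_t, 2 L~ @ (L~ @ X_t) - X_t, 2 L~ @ T_2 - T_1]
    # so each additional order costs one [N, N] @ [N, d] matmul and no
    # eigendecomposition of L~ beyond the lambda_max estimate.
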
    def forward_one_step(self, X_t, L):
        """Compute one-step reconstruction of X_{t+1}"""
        lambda_max = torch.linalg.eigvalsh(L).max()
        L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
        T_k_list = self.chebyshev_polynomials(L_tilde, X_t)

        H_list = []
        for k in range(self.K + 1):
            H_k = T_k_list[k] @ self.W[k]  # [N, hidden_dim]
            H_list.append(H_k)
        H_stack = torch.stack(H_list, dim=-1)  # [N, hidden_dim, K+1]
        H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1)
        X_hat_next = torch.tanh(H_weighted)
        return X_hat_next, H_list

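    # What the weighted sum implements: the one-step reconstruction is
    #   X_hat_{t+1} = tanh( sum_{k=0}^{K} alpha_k * T_k(L~) X_t W_k ),
    # i.e. a learned K-th order polynomial spectral filter whose frequency
    # response sum_k alpha_k T_k(lambda) is controlled by alpha.
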
    def forward_sequence(self, X, A):
        """
        Compute decoupled multi-scale representation over all time steps.
        X: [N, T, d], A: [N, N]
        """
        L = self.compute_normalized_laplacian(A)
        N, T, d = X.shape
        multi_scale_list = []

        # Build the multi-scale block for every time step. Every entry must
        # have the same width hidden_dim*(K+2), otherwise torch.stack below
        # fails; seeding the list with the raw X[:, 0, :] (width d) would
        # therefore break.
        for t in range(T):
            X_hat_next, H_list = self.forward_one_step(X[:, t, :], L)
            X_tilde_t = torch.cat(H_list + [X_hat_next], dim=-1)
            multi_scale_list.append(X_tilde_t)

        X_tilde = torch.stack(multi_scale_list, dim=1)
        # Shape: [N, T, hidden_dim*(K+2)]
        return X_tilde

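    # Shape bookkeeping (K = 3, hidden_dim = 16): every step concatenates the
    # K + 1 = 4 propagated blocks H_0..H_3 plus the reconstruction X_hat_next,
    # i.e. K + 2 = 5 blocks of width 16, so X_tilde is [N, T, 80].
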
    def loss_fn(self, X_true, X_pred):
        return F.mse_loss(X_true, X_pred)

    def fit(self, X, A, num_epochs=200, verbose=True):
        """Optimization process to learn α_k and W_k"""
        L = self.compute_normalized_laplacian(A)
        for epoch in range(num_epochs):
            total_loss = 0.0
            # One gradient step per consecutive pair (X_t, X_{t+1}),
            # i.e. T - 1 optimizer steps per epoch.
            for t in range(X.size(1) - 1):
                X_t = X[:, t, :]
                X_true_next = X[:, t + 1, :]
                X_pred_next, _ = self.forward_one_step(X_t, L)
                loss = self.loss_fn(X_true_next, X_pred_next)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()

            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}")

        print("Training complete.")

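    # Performance note (a suggestion, not part of the training loop above):
    # forward_one_step re-runs torch.linalg.eigvalsh(L) on every call, which is
    # O(N^3). Since L is fixed throughout fit(), lambda_max could be computed
    # once up front, e.g.
    #   lambda_max = torch.linalg.eigvalsh(L).max()
    # and passed through instead of recomputed at each step.
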

# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    N, T, d = 20, 10, 16
    A = torch.rand(N, N)
    A = (A + A.T) / 2  # symmetrize
    X = torch.randn(N, T, d)

    model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu')
    model.fit(X, A, num_epochs=50)

    X_tilde = model.forward_sequence(X, A)
    print("Final embedding shape:", X_tilde.shape)  # [N, T, hidden_dim*(K+2)]
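
    # Downstream sketch (hypothetical: the linear readout below is not part of
    # the model; it only illustrates consuming the [N, T, hidden_dim*(K+2)]
    # features produced above).
    with torch.no_grad():
        readout = nn.Linear(X_tilde.size(-1), 1)
        y_hat = readout(X_tilde)  # [N, T, 1]
        print("Readout output shape:", y_hat.shape)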