Compare commits

...

2 Commits

Author SHA1 Message Date
czzhangheng 15f083c3d9 Fix permute bug 2025-11-12 16:40:13 +08:00
czzhangheng 095e8c60dc Fix missing-permute bug 2025-11-12 16:40:09 +08:00
5 changed files with 148 additions and 3 deletions

View File

@@ -1,7 +1,7 @@
basic:
  dataset: "PEMS-BAY"
  mode : "train"
  device : "cuda:0"
  device : "cuda:1"
  model: "REPST"
data:

View File

@@ -0,0 +1,121 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


class ChebSpectralReconstructor(nn.Module):
    """
    Spectral reconstruction using Chebyshev polynomial approximation.

    Input:
        X: Tensor of shape [N, T, d] (node features over time)
        A: Adjacency matrix [N, N]
    Output:
        X_tilde: Tensor [N, T, hidden_dim*(K+2)] for the downstream model

    Note: fit() compares the one-step reconstruction X_hat_{t+1} (width
    hidden_dim) against X_{t+1} (width in_dim), so hidden_dim must equal
    in_dim whenever fit() is used, as in the example below.
    """

    def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'):
        super().__init__()
        self.K = K
        self.in_dim = in_dim
        self.hidden_dim = hidden_dim
        self.device = device
        # Spectral coefficients alpha_k (learnable)
        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
        # Propagation weights W_k (learnable)
        self.W = nn.ParameterList([
            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
        ])
        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
        self.to(device)

    def compute_normalized_laplacian(self, A):
        """Compute the normalized Laplacian L = I - D^{-1/2} A D^{-1/2}."""
        I = torch.eye(A.size(0), device=A.device)
        deg = torch.sum(A, dim=1)
        d_inv_sqrt = torch.pow(deg, -0.5)
        d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.0  # guard isolated nodes (degree 0)
        D_inv_sqrt = torch.diag(d_inv_sqrt)
        L = I - D_inv_sqrt @ A @ D_inv_sqrt
        return L

    def chebyshev_polynomials(self, L_tilde, X_t):
        """Compute [T_0(L~)X_t, ..., T_K(L~)X_t] via the recurrence
        T_k(L~)X = 2 L~ T_{k-1}(L~)X - T_{k-2}(L~)X."""
        T_k_list = [X_t]                       # T_0(L~)X = X
        if self.K >= 1:
            T_k_list.append(L_tilde @ X_t)     # T_1(L~)X = L~ X
        for k in range(2, self.K + 1):
            T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2])
        return T_k_list

    def forward_one_step(self, X_t, L):
        """Compute the one-step reconstruction of X_{t+1} from X_t."""
        # Rescale L so its spectrum lies in [-1, 1], the interval on which
        # the Chebyshev polynomials are defined: L~ = (2 / lambda_max) L - I.
        lambda_max = torch.linalg.eigvalsh(L).max()
        L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
        T_k_list = self.chebyshev_polynomials(L_tilde, X_t)
        H_list = []
        for k in range(self.K + 1):
            H_k = T_k_list[k] @ self.W[k]      # [N, hidden_dim]
            H_list.append(H_k)
        H_stack = torch.stack(H_list, dim=-1)  # [N, hidden_dim, K+1]
        H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1)
        X_hat_next = torch.tanh(H_weighted)    # [N, hidden_dim]
        return X_hat_next, H_list

    def forward_sequence(self, X, A):
        """
        Compute the decoupled multi-scale representation over all time steps.
        X: [N, T, d], A: [N, N]
        """
        L = self.compute_normalized_laplacian(A)
        N, T, d = X.shape
        # Step 0 has no predecessor, so reconstruct it from itself; this keeps
        # every entry at width hidden_dim*(K+2) so the stack below is valid
        # (the raw slice X[:, 0, :] has width d and could not be stacked).
        X_hat_0, H_list_0 = self.forward_one_step(X[:, 0, :], L)
        multi_scale_list = [torch.cat(H_list_0 + [X_hat_0], dim=-1)]
        for t in range(T - 1):
            X_hat_next, H_list = self.forward_one_step(X[:, t, :], L)
            X_tilde_next = torch.cat(H_list + [X_hat_next], dim=-1)
            multi_scale_list.append(X_tilde_next)
        X_tilde = torch.stack(multi_scale_list, dim=1)
        # Shape: [N, T, hidden_dim*(K+2)]
        return X_tilde

    def loss_fn(self, X_true, X_pred):
        return F.mse_loss(X_true, X_pred)

    def fit(self, X, A, num_epochs=200, verbose=True):
        """Optimize alpha_k and W_k with a one-step reconstruction objective
        (requires hidden_dim == in_dim)."""
        L = self.compute_normalized_laplacian(A)
        for epoch in range(num_epochs):
            total_loss = 0.0
            for t in range(X.size(1) - 1):
                X_t = X[:, t, :]
                X_true_next = X[:, t + 1, :]
                X_pred_next, _ = self.forward_one_step(X_t, L)
                loss = self.loss_fn(X_true_next, X_pred_next)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                total_loss += loss.item()
            if verbose and (epoch + 1) % 10 == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}")
        print("Training complete.")


# ---------------------------
# Example usage
# ---------------------------
if __name__ == "__main__":
    N, T, d = 20, 10, 16
    A = torch.rand(N, N)
    A = (A + A.T) / 2  # symmetrize
    X = torch.randn(N, T, d)
    # hidden_dim == d so the reconstruction loss in fit() is well-defined
    model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu')
    model.fit(X, A, num_epochs=50)
    X_tilde = model.forward_sequence(X, A)
    print("Final embedding shape:", X_tilde.shape)  # [N, T, hidden_dim*(K+2)] = [20, 10, 80]

View File

@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
    def forward(self, x):
        b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
        # 768,64,25
        # Why was reshape done without a permute first?
        x = x.permute(0, 1, 4, 3, 2)
        x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, patch_num*feature
        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)
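
The added permute matters because reshape only reinterprets memory order, it never reorders axes: without first permuting to (b, n, pl, pn, m), reshape(b*n, pl, m*pn) interleaves values from different patches and features into each patch_len row, which is presumably not the layout tokenConv expects. A tiny sketch of the difference (toy sizes, not the model's real dimensions):

import torch

b, n, m, pn, pl = 1, 1, 2, 3, 4
x = torch.arange(b * n * m * pn * pl).reshape(b, n, m, pn, pl)

# reshape alone: each "patch_len" row mixes entries across patches/features
wrong = x.reshape(b * n, pl, m * pn)

# permute first: each row holds one patch_len position across all
# (patch, feature) pairs
right = x.permute(0, 1, 4, 3, 2).reshape(b * n, pl, m * pn)

print(torch.equal(wrong, right))  # False: the two layouts differ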

View File

@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
    def forward(self, x):
        b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
        # 768,64,25
        # Why was reshape done without a permute first?
        x = x.permute(0, 1, 4, 3, 2)
        x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, patch_num*feature
        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)

View File

@@ -1,3 +1,4 @@
from tkinter import Y  # appears to be a stray auto-import; Y is unused
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@@ -77,4 +78,25 @@ class repst(nn.Module):
        return outputs

if __name__ == '__main__':
    configs = {
        'device': 'cuda:0',
        'pred_len': 24,
        'seq_len': 24,
        'patch_len': 6,
        'stride': 7,
        'dropout': 0.2,
        'gpt_layers': 9,
        'd_ff': 128,
        'gpt_path': './GPT-2',
        'd_model': 64,
        'n_heads': 1,
        'input_dim': 1
    }
    model = repst(configs)
    x = torch.randn(16, 24, 325, 1)  # (batch, seq_len, num_nodes, input_dim)
    y = model(x)
    print(y.shape)
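
One detail worth flagging in this smoke test: patch_len=6 with stride=7 means the stride exceeds the patch length, so under a standard unfold-style patching (an assumption about repst's internals; the actual patching code may differ) seq_len=24 yields only 3 patches and the last 4 time steps are never covered:

import torch

seq_len, patch_len, stride = 24, 6, 7
t = torch.arange(seq_len)
patches = t.unfold(0, patch_len, stride)  # (patch_num, patch_len)
print(patches.shape)                      # torch.Size([3, 6]): floor((24-6)/7)+1 = 3
print(patches[-1])                        # tensor([14, 15, 16, 17, 18, 19]); steps 20..23 dropped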