Compare commits
2 Commits
38431fb4c1 ... 15f083c3d9
| Author | SHA1 | Date |
|---|---|---|
| | 15f083c3d9 | |
| | 095e8c60dc | |
@@ -1,7 +1,7 @@
 basic:
   dataset: "PEMS-BAY"
   mode : "train"
-  device : "cuda:0"
+  device : "cuda:1"
   model: "REPST"
 
 data:
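For context, the file touched above is a plain YAML training config. A minimal sketch of how such a config is typically consumed follows; the file name config.yaml and the use of PyYAML are assumptions for illustration, not taken from this repo:

import torch
import yaml  # PyYAML, assumed available

with open("config.yaml") as f:  # hypothetical file name
    cfg = yaml.safe_load(f)

# After this commit the training device resolves to "cuda:1".
device = torch.device(cfg["basic"]["device"])
print(cfg["basic"]["model"], "on", device)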
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ChebSpectralReconstructor(nn.Module):
+    """
+    Spectral reconstruction using Chebyshev polynomial approximation.
+    Input:
+        X: Tensor of shape [N, T, d] (node features over time)
+        A: Adjacency matrix [N, N]
+    Output:
+        X_tilde: Tensor [N, T, hidden_dim*(K+2)] for downstream model
+    """
+    def __init__(self, K=3, in_dim=16, hidden_dim=32, lr=1e-3, device='cpu'):
+        super().__init__()
+        self.K = K
+        self.in_dim = in_dim
+        self.hidden_dim = hidden_dim
+        self.device = device
+
+        # Spectral coefficients α_k (learnable)
+        self.alpha = nn.Parameter(torch.randn(K + 1, 1))
+
+        # Propagation weights W_k (learnable)
+        self.W = nn.ParameterList([
+            nn.Parameter(torch.randn(in_dim, hidden_dim)) for _ in range(K + 1)
+        ])
+
+        self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)
+        self.to(device)
+
+    def compute_normalized_laplacian(self, A):
+        """Compute normalized Laplacian L = I - D^{-1/2} A D^{-1/2}"""
+        I = torch.eye(A.size(0), device=A.device)
+        D = torch.diag(torch.sum(A, dim=1))
+        D_inv_sqrt = torch.pow(D, -0.5)
+        # Off-diagonal zeros become inf under pow(-0.5); zero them out.
+        D_inv_sqrt[torch.isinf(D_inv_sqrt)] = 0.0
+        L = I - D_inv_sqrt @ A @ D_inv_sqrt
+        return L
+
+    def chebyshev_polynomials(self, L_tilde, X_t):
+        """Compute [T_0(L_tilde)X_t, ..., T_K(L_tilde)X_t] recursively"""
+        T_k_list = [X_t]
+        if self.K >= 1:
+            T_k_list.append(L_tilde @ X_t)
+        for k in range(2, self.K + 1):
+            T_k_list.append(2 * L_tilde @ T_k_list[-1] - T_k_list[-2])
+        return T_k_list
+
+    def forward_one_step(self, X_t, L):
+        """Compute one-step reconstruction of X_{t+1}"""
+        lambda_max = torch.linalg.eigvalsh(L).max().real
+        L_tilde = (2.0 / lambda_max) * L - torch.eye(L.size(0), device=L.device)
+        T_k_list = self.chebyshev_polynomials(L_tilde, X_t)
+
+        H_list = []
+        for k in range(self.K + 1):
+            H_k = T_k_list[k] @ self.W[k]
+            H_list.append(H_k)
+        H_stack = torch.stack(H_list, dim=-1)  # [N, hidden_dim, K+1]
+        H_weighted = torch.sum(H_stack * self.alpha.view(1, 1, -1), dim=-1)
+        X_hat_next = torch.tanh(H_weighted)
+        return X_hat_next, H_list
+
+    def forward_sequence(self, X, A):
+        """
+        Compute decoupled multi-scale representation over all time steps.
+        X: [N, T, d], A: [N, N]
+        """
+        L = self.compute_normalized_laplacian(A)
+        N, T, d = X.shape
+        # X_1 has width d while reconstructed steps have width
+        # hidden_dim*(K+2); zero-pad it (assumes d <= hidden_dim*(K+2))
+        # so that all time steps stack to a uniform width below.
+        pad = self.hidden_dim * (self.K + 2) - d
+        multi_scale_list = [F.pad(X[:, 0, :], (0, pad))]  # include X_1
+
+        for t in range(T - 1):
+            X_hat_next, H_list = self.forward_one_step(X[:, t, :], L)
+            X_tilde_next = torch.cat(H_list + [X_hat_next], dim=-1)
+            multi_scale_list.append(X_tilde_next)
+
+        X_tilde = torch.stack(multi_scale_list, dim=1)
+        # Shape: [N, T, hidden_dim*(K+2)]
+        return X_tilde
+
+    def loss_fn(self, X_true, X_pred):
+        return F.mse_loss(X_true, X_pred)
+
+    def fit(self, X, A, num_epochs=200, verbose=True):
+        """Optimization process to learn α_k and W_k"""
+        # The one-step reconstruction is hidden_dim wide, so comparing it
+        # against the raw features assumes hidden_dim == in_dim.
+        L = self.compute_normalized_laplacian(A)
+        for epoch in range(num_epochs):
+            total_loss = 0
+            for t in range(X.size(1) - 1):
+                X_t = X[:, t, :]
+                X_true_next = X[:, t + 1, :]
+                X_pred_next, _ = self.forward_one_step(X_t, L)
+                loss = self.loss_fn(X_true_next, X_pred_next)
+
+                self.optimizer.zero_grad()
+                loss.backward()
+                self.optimizer.step()
+                total_loss += loss.item()
+
+            if verbose and (epoch + 1) % 10 == 0:
+                print(f"Epoch [{epoch+1}/{num_epochs}] | Loss: {total_loss:.6f}")
+
+        print("Training complete.")
+
+
+# ---------------------------
+# Example usage
+# ---------------------------
+if __name__ == "__main__":
+    N, T, d = 20, 10, 16
+    A = torch.rand(N, N)
+    A = (A + A.T) / 2  # symmetrize
+    X = torch.randn(N, T, d)
+
+    model = ChebSpectralReconstructor(K=3, in_dim=d, hidden_dim=16, lr=1e-3, device='cpu')
+    model.fit(X, A, num_epochs=50)
+
+    X_tilde = model.forward_sequence(X, A)
+    print("Final embedding shape:", X_tilde.shape)  # [N, T, hidden_dim*(K+2)]
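For reference, here is the math the new module implements, written out from the code above (standard Chebyshev graph-filter notation; the symbols match the class attributes, nothing beyond the code is assumed):

L = I - D^{-1/2} A D^{-1/2}, \qquad \tilde{L} = \frac{2}{\lambda_{\max}} L - I

T_0(\tilde{L}) = I, \quad T_1(\tilde{L}) = \tilde{L}, \quad T_k(\tilde{L}) = 2\tilde{L}\,T_{k-1}(\tilde{L}) - T_{k-2}(\tilde{L})

\hat{X}_{t+1} = \tanh\Big(\sum_{k=0}^{K} \alpha_k\, T_k(\tilde{L})\, X_t\, W_k\Big)

The per-hop products H_k = T_k(\tilde{L}) X_t W_k are kept and concatenated with \hat{X}_{t+1}, which is where the hidden_dim*(K+2) output width comes from.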
@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
 
     def forward(self, x):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
-        # 768,64,25
+        # Why wasn't this permuted before reshaping?
+        x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)
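As a quick sanity check on the added permute, here is the shape flow in isolation (a standalone sketch with made-up sizes; tokenConv and confusion_layer are not reproduced):

import torch

b, n, m, pn, pl = 2, 3, 1, 4, 6   # made-up sizes
x = torch.randn(b, n, m, pn, pl)  # [batch, node, feature, patch_num, patch_len]
x = x.permute(0, 1, 4, 3, 2)      # [b, n, pl, pn, m]
x = x.reshape(b * n, pl, m * pn)  # [b*n, pl, pn*m]
print(x.shape)                    # torch.Size([6, 6, 4])

Note that with the permute in place, the last axis of the reshape flattens in patch_num-major order, so the inline comment's feature*patch_num describes the size rather than the memory layout.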
@@ -1,3 +1,4 @@
+from tkinter import Y
 import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@@ -77,4 +78,25 @@ class repst(nn.Module):
 
         return outputs
 
+if __name__ == '__main__':
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+
+    print(y.shape)