AEPSA v0.1

This commit is contained in:
czzhangheng 2025-11-24 21:53:01 +08:00
parent 9c50c30918
commit 7c984c6fd2
1 changed files with 1 additions and 2 deletions

View File

@ -227,11 +227,10 @@ class AEPSA(nn.Module):
""" """
x = x[..., :1] x = x[..., :1]
x_enc = rearrange(x, 'b t n c -> b n c t') x_enc = rearrange(x, 'b t n c -> b n c t')
# 原版Patch
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C) enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
# 应用图增强编码器(自动生成图结构) # 应用图增强编码器(自动生成图结构)
graph_enhanced = self.graph_encoder(enc_out) graph_enhanced = self.graph_encoder(enc_out)
# 保持相同的维度
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model] # 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1) enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
enc_out = self.feature_fusion(enc_out) enc_out = self.feature_fusion(enc_out)