AEPSA v0.1
This commit is contained in:
parent
9c50c30918
commit
7c984c6fd2
|
|
@ -227,11 +227,10 @@ class AEPSA(nn.Module):
|
|||
"""
|
||||
x = x[..., :1]
|
||||
x_enc = rearrange(x, 'b t n c -> b n c t')
|
||||
# 原版Patch
|
||||
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
|
||||
# 应用图增强编码器(自动生成图结构)
|
||||
graph_enhanced = self.graph_encoder(enc_out)
|
||||
# 保持相同的维度
|
||||
|
||||
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
||||
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
||||
enc_out = self.feature_fusion(enc_out)
|
||||
|
|
|
|||
Loading…
Reference in New Issue