AEPSA v0.1
This commit is contained in:
parent
9c50c30918
commit
7c984c6fd2
|
|
@ -227,11 +227,10 @@ class AEPSA(nn.Module):
|
||||||
"""
|
"""
|
||||||
x = x[..., :1]
|
x = x[..., :1]
|
||||||
x_enc = rearrange(x, 'b t n c -> b n c t')
|
x_enc = rearrange(x, 'b t n c -> b n c t')
|
||||||
|
# 原版Patch
|
||||||
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
|
enc_out, n_vars = self.patch_embedding(x_enc) # (B, N, C)
|
||||||
# 应用图增强编码器(自动生成图结构)
|
# 应用图增强编码器(自动生成图结构)
|
||||||
graph_enhanced = self.graph_encoder(enc_out)
|
graph_enhanced = self.graph_encoder(enc_out)
|
||||||
# 保持相同的维度
|
|
||||||
|
|
||||||
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
# 特征融合 - 现在两个张量都是三维的 [B, N, d_model]
|
||||||
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
enc_out = torch.cat([enc_out, graph_enhanced], dim=-1)
|
||||||
enc_out = self.feature_fusion(enc_out)
|
enc_out = self.feature_fusion(enc_out)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue