From f899f50b163c9c5b4c2882a00c432bd0125eaa11 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sat, 6 Dec 2025 19:48:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8F=98=E9=87=8F=E5=90=8D=E6=9B=B4=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- model/AEPSA/aepsav3.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/model/AEPSA/aepsav3.py b/model/AEPSA/aepsav3.py index 6a579b6..99c6748 100644 --- a/model/AEPSA/aepsav3.py +++ b/model/AEPSA/aepsav3.py @@ -190,7 +190,7 @@ class AEPSA(nn.Module): # 图编码 H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)] X_t_1 = self.graph_projection(H_t) # [B,N,d_model] - enc_out = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] + X_enc = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)] # 词嵌入处理 self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -198,9 +198,9 @@ class AEPSA(nn.Module): source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm] # 重编程与预测 - enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) - enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state # [B,N,d_llm] - dec_out = self.out_mlp(enc_out) # [B,N,pred_len] + X_enc = self.reprogramming_layer(X_enc, source_embeddings, source_embeddings) + X_enc = self.gpts(inputs_embeds=X_enc).last_hidden_state # [B,N,d_llm] + dec_out = self.out_mlp(X_enc) # [B,N,pred_len] # 维度调整 outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1]