Compare commits


No commits in common. "15989181126cecd49b4b7a5673f5f29b519d2586" and "865c5a30823c507b486fddb8faf530142e584aa1" have entirely different histories.

2 changed files with 4 additions and 6 deletions


@@ -19,11 +19,9 @@ data:
   val_ratio: 0.2
 model:
-  chebyshev_order: 3
   d_ff: 128
   d_model: 64
   dropout: 0.2
-  graph_hidden_dim: 32
   gpt_layers: 9
   gpt_path: ./GPT-2
   input_dim: 1
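
The configuration change only drops the two graph-encoder hyperparameters (chebyshev_order and graph_hidden_dim); the remaining keys are untouched. As a minimal sketch of the consequence (the repository's actual loading code is not part of this compare view, so the file name and the fallback defaults below are assumptions; only the key names come from the diff), the trimmed model: block could be consumed like this:

import yaml

# Hypothetical loader sketch; "config.yaml" and the default values are assumptions.
with open("config.yaml") as f:
    cfg = yaml.safe_load(f)

model_cfg = cfg["model"]
d_model = model_cfg["d_model"]          # 64
gpt_layers = model_cfg["gpt_layers"]    # 9
# Removed in this commit, so the code itself now has to supply any defaults:
chebyshev_order = model_cfg.get("chebyshev_order", 3)
graph_hidden_dim = model_cfg.get("graph_hidden_dim", 32)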


@@ -190,7 +190,7 @@ class AEPSA(nn.Module):
         # Graph encoding
         H_t = self.graph_encoder(x_enc)  # [B,N,1,T] -> [B, N, hidden_dim*(K+1)]
         X_t_1 = self.graph_projection(H_t)  # [B,N,d_model]
-        X_enc = torch.cat([H_t, X_t_1], dim = -1)  # [B, N, d_model + hidden_dim*(K+1)]
+        enc_out = torch.cat([H_t, X_t_1], dim = -1)  # [B, N, d_model + hidden_dim*(K+1)]
         # Word embedding processing
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@@ -198,9 +198,9 @@ class AEPSA(nn.Module):
         source_embeddings = self.word_embeddings[masks==1]  # [selected_words,d_llm]
         # Reprogramming and prediction
-        X_enc = self.reprogramming_layer(X_enc, source_embeddings, source_embeddings)
-        X_enc = self.gpts(inputs_embeds=X_enc).last_hidden_state  # [B,N,d_llm]
-        dec_out = self.out_mlp(X_enc)  # [B,N,pred_len]
+        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
+        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state  # [B,N,d_llm]
+        dec_out = self.out_mlp(enc_out)  # [B,N,pred_len]
         # Dimension adjustment
         outputs = dec_out.unsqueeze(dim=-1)  # [B,N,pred_len,1]
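
The second file is a pure rename of the intermediate tensor from X_enc to enc_out; the data flow is unchanged. The shape-only sketch below retraces that flow with placeholder modules standing in for graph_encoder, graph_projection, reprogramming_layer, the GPT-2 backbone (self.gpts) and out_mlp. Every dimension except d_model = 64 is an illustrative assumption rather than a value taken from the repository, and the placeholder layers do not reproduce the real graph or cross-attention reprogramming computations.

import torch
import torch.nn as nn

# Illustrative sizes only; graph_dim plays the role of hidden_dim*(K+1) in the comments above.
B, N, T = 8, 7, 96
d_model, graph_dim, d_llm, pred_len = 64, 128, 768, 24

graph_encoder = nn.Sequential(nn.Flatten(start_dim=2), nn.Linear(T, graph_dim))  # [B,N,1,T] -> [B,N,graph_dim]
graph_projection = nn.Linear(graph_dim, d_model)
reprogramming_layer = nn.Linear(graph_dim + d_model, d_llm)  # stand-in for cross-attention reprogramming
gpts = nn.Identity()                                         # stand-in for self.gpts(...).last_hidden_state
out_mlp = nn.Linear(d_llm, pred_len)

x_enc = torch.randn(B, N, 1, T)
H_t = graph_encoder(x_enc)                 # [B, N, graph_dim]
X_t_1 = graph_projection(H_t)              # [B, N, d_model]
enc_out = torch.cat([H_t, X_t_1], dim=-1)  # [B, N, graph_dim + d_model]
enc_out = reprogramming_layer(enc_out)     # [B, N, d_llm]
enc_out = gpts(enc_out)                    # [B, N, d_llm]
dec_out = out_mlp(enc_out)                 # [B, N, pred_len]
outputs = dec_out.unsqueeze(dim=-1)        # [B, N, pred_len, 1]
print(outputs.shape)                       # torch.Size([8, 7, 24, 1])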