Compare commits
2 Commits
865c5a3082
...
1598918112
| Author | SHA1 | Date |
|---|---|---|
|
|
1598918112 | |
|
|
f899f50b16 |
|
|
@ -19,9 +19,11 @@ data:
|
||||||
val_ratio: 0.2
|
val_ratio: 0.2
|
||||||
|
|
||||||
model:
|
model:
|
||||||
|
chebyshev_order: 3
|
||||||
d_ff: 128
|
d_ff: 128
|
||||||
d_model: 64
|
d_model: 64
|
||||||
dropout: 0.2
|
dropout: 0.2
|
||||||
|
graph_hidden_dim: 32
|
||||||
gpt_layers: 9
|
gpt_layers: 9
|
||||||
gpt_path: ./GPT-2
|
gpt_path: ./GPT-2
|
||||||
input_dim: 1
|
input_dim: 1
|
||||||
|
|
|
||||||
|
|
@ -190,7 +190,7 @@ class AEPSA(nn.Module):
|
||||||
# 图编码
|
# 图编码
|
||||||
H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)]
|
H_t = self.graph_encoder(x_enc) # [B,N,1,T] -> [B, N, hidden_dim*(K+1)]
|
||||||
X_t_1 = self.graph_projection(H_t) # [B,N,d_model]
|
X_t_1 = self.graph_projection(H_t) # [B,N,d_model]
|
||||||
enc_out = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)]
|
X_enc = torch.cat([H_t, X_t_1], dim = -1) # [B, N, d_model + hidden_dim*(K+1)]
|
||||||
|
|
||||||
# 词嵌入处理
|
# 词嵌入处理
|
||||||
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
|
||||||
|
|
@ -198,9 +198,9 @@ class AEPSA(nn.Module):
|
||||||
source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm]
|
source_embeddings = self.word_embeddings[masks==1] # [selected_words,d_llm]
|
||||||
|
|
||||||
# 重编程与预测
|
# 重编程与预测
|
||||||
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
|
X_enc = self.reprogramming_layer(X_enc, source_embeddings, source_embeddings)
|
||||||
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state # [B,N,d_llm]
|
X_enc = self.gpts(inputs_embeds=X_enc).last_hidden_state # [B,N,d_llm]
|
||||||
dec_out = self.out_mlp(enc_out) # [B,N,pred_len]
|
dec_out = self.out_mlp(X_enc) # [B,N,pred_len]
|
||||||
|
|
||||||
# 维度调整
|
# 维度调整
|
||||||
outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1]
|
outputs = dec_out.unsqueeze(dim=-1) # [B,N,pred_len,1]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue