Adjustable word_num

czzhangheng 2025-11-12 17:12:30 +08:00
parent 15f083c3d9
commit e61706c391
5 changed files with 67 additions and 22 deletions

View File

@@ -34,6 +34,7 @@ model:
   d_model: 64
   n_heads: 1
   input_dim: 1
+  word_num: 2000
 train:
   batch_size: 16
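The new word_num key makes the selection size configurable (it was previously only the hard-coded k=1000 default of GumbelSoftmax, shown in the next file); 2000 vocabulary embeddings are now kept. A minimal sketch of how the value might be read from this YAML file, assuming the 'model' sub-dict is what gets passed into repst (the loader script and file path are not part of this commit):

import yaml

# Hypothetical loading step -- the config path and the surrounding training
# script are assumptions; only the 'word_num: 2000' key comes from this commit.
with open('config.yaml') as f:
    cfg = yaml.safe_load(f)

model_cfg = cfg['model']        # repst() indexes configs['word_num'] directly,
print(model_cfg['word_num'])    # so presumably this sub-dict is what it receives -> 2000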

View File

@@ -3,7 +3,16 @@ import torch.nn as nn
 import torch.nn.functional as F
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
         y_soft = F.gumbel_softmax(logits, tau, hard)
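The hunk is cut off after the first line of gumbel_softmax, so the actual selection logic is not visible here. A minimal sketch of what a module with this interface could look like, assuming the goal is a binary mask over vocabulary rows that keeps the k highest-scoring entries (the body of gumbel_softmax below is an assumption, not the repository's code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GumbelSoftmax(nn.Module):
    # Select k rows of a logits matrix via a (hard) Gumbel-softmax sample.
    def __init__(self, k=1000, hard=True):
        super(GumbelSoftmax, self).__init__()
        self.k = k
        self.hard = hard

    def forward(self, logits):
        return self.gumbel_softmax(logits, 1, self.k, self.hard)

    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
        # Per-row relaxed one-hot sample (straight-through if hard=True).
        y_soft = F.gumbel_softmax(logits, tau, hard)
        # Score each vocabulary row by its sampled mass and keep the top k.
        scores = y_soft.sum(dim=-1)
        top_idx = scores.topk(k).indices
        mask = torch.zeros_like(scores)
        mask[top_idx] = 1.0   # consumed later as word_embeddings[mask == 1]
        return mask

This shape is at least consistent with how repst uses the module below: the returned mask has one entry per vocabulary word, and exactly k of them are 1.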

View File

@@ -1,8 +1,9 @@
+from tkinter import Y
 import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 class repst(nn.Module):
@@ -20,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
+        self.word_choice = GumbelSoftmax(configs['word_num'])
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
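With the selector registered as a submodule here, the number of retained words follows the config rather than the class default of 1000. A hypothetical defensive variant (not in the commit) that would keep older configs without the new key working:

from model.REPST.normalizer import GumbelSoftmax

configs = {'word_num': 2000}   # hypothetical dict; the key matches the YAML change above

# Fall back to the previous hard-coded k=1000 when 'word_num' is absent.
word_choice = GumbelSoftmax(k=configs.get('word_num', 1000))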
@@ -64,7 +67,7 @@ class repst(nn.Module):
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
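In this forward pass the selector now runs over the mapping layer's weight, permuted to (vocab_size, num_tokens), and the resulting binary mask picks word_num rows of the word-embedding table to serve as the reprogramming prototypes. A shape walkthrough under the hypothetical GumbelSoftmax sketch above, with assumed sizes (GPT-2 vocab 50257, embedding dim 768, arbitrary num_tokens):

import torch
from model.REPST.normalizer import GumbelSoftmax   # behavior assumed to match the sketch above

vocab_size, embed_dim, num_tokens = 50257, 768, 1000   # assumed sizes
word_embeddings = torch.randn(vocab_size, embed_dim)   # stands in for GPT-2's embedding weight
mapping_weight = torch.randn(num_tokens, vocab_size)   # nn.Linear(vocab_size, num_tokens).weight

logits = mapping_weight.permute(1, 0)                  # (vocab_size, num_tokens)
masks = GumbelSoftmax(k=2000)(logits)                  # binary vector with 2000 ones, per the sketch
source_embeddings = word_embeddings[masks == 1]        # (2000, embed_dim), fed to the reprogramming layer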
@@ -77,4 +80,25 @@ class repst(nn.Module):
         return outputs
+if __name__ == '__main__':
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+    print(y.shape)
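One caveat with this smoke test as committed: the new __init__ reads configs['word_num'] (see the hunk above), but the dict here does not define that key, so running the file would raise a KeyError. A corrected dict would add the key, reusing the value introduced in the YAML config (all other values copied from the commit):

configs = {
    'device': 'cuda:0',
    'pred_len': 24,
    'seq_len': 24,
    'patch_len': 6,
    'stride': 7,
    'dropout': 0.2,
    'gpt_layers': 9,
    'd_ff': 128,
    'gpt_path': './GPT-2',
    'd_model': 64,
    'n_heads': 1,
    'input_dim': 1,
    'word_num': 2000,   # missing from the committed test; required by GumbelSoftmax(configs['word_num'])
}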

View File

@@ -3,7 +3,16 @@ import torch.nn as nn
 import torch.nn.functional as F
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
         y_soft = F.gumbel_softmax(logits, tau, hard)

View File

@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 class repst(nn.Module):
@@ -21,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
+        self.word_choice = GumbelSoftmax(configs['word_num'])
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
@@ -65,7 +67,7 @@ class repst(nn.Module):
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)