Adjustable word_num (可调word_num)

parent 15f083c3d9
commit e61706c391
@@ -34,6 +34,7 @@ model:
   d_model: 64
   n_heads: 1
   input_dim: 1
+  word_num: 2000
 
 train:
   batch_size: 16
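Judging from the model changes below, word_num sets how many GPT-2 word embeddings the hard Gumbel top-k selection keeps: repst.__init__ passes configs['word_num'] to GumbelSoftmax as its k.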
@@ -3,14 +3,23 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
-
-    y_soft = F.gumbel_softmax(logits, tau, hard)
-
-    if hard:
-        # build the hard mask
-        _, indices = y_soft.topk(k, dim=0)  # pick the top-k entries
-        y_hard = torch.zeros_like(logits)
-        y_hard.scatter_(0, indices, 1)
-        return torch.squeeze(y_hard, dim=-1)
-    return torch.squeeze(y_soft, dim=-1)
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
+
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
+
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
+
+        y_soft = F.gumbel_softmax(logits, tau, hard)
+
+        if hard:
+            # build the hard mask
+            _, indices = y_soft.topk(k, dim=0)  # pick the top-k entries
+            y_hard = torch.zeros_like(logits)
+            y_hard.scatter_(0, indices, 1)
+            return torch.squeeze(y_hard, dim=-1)
+        return torch.squeeze(y_soft, dim=-1)
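For reference, a minimal usage sketch of the new module (the vocabulary size 50257 and the single score column are illustrative assumptions, not values taken from this repo):

    import torch
    from model.REPST.normalizer import GumbelSoftmax

    gs = GumbelSoftmax(k=2000, hard=True)   # k plays the role of word_num
    logits = torch.randn(50257, 1)          # one score per hypothetical vocabulary entry
    mask = gs(logits)                       # shape (50257,), exactly k entries equal 1
    print(int(mask.sum()))                  # 2000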
@@ -1,8 +1,9 @@
+from tkinter import Y
 import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 
 class repst(nn.Module):
@@ -20,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
 
+        self.word_choice = GumbelSoftmax(configs['word_num'])
+
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
@@ -64,7 +67,7 @@ class repst(nn.Module):
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
 
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
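A small sketch of the selection step under assumed shapes (GPT-2's 50257 x 768 embedding table; the tensors here are stand-ins, not repo values): the hard 0/1 mask returned by word_choice picks the rows of word_embeddings that serve as reprogramming prototypes.

    import torch
    from model.REPST.normalizer import GumbelSoftmax

    word_embeddings = torch.randn(50257, 768)        # assumed GPT-2 embedding table
    scores = torch.randn(50257, 1)                   # stand-in for mapping_layer.weight.data.permute(1, 0)
    masks = GumbelSoftmax(k=2000)(scores)            # hard mask with exactly k ones
    source_embeddings = word_embeddings[masks == 1]  # (2000, 768) selected prototype rows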
@@ -77,4 +80,25 @@ class repst(nn.Module):
 
         return outputs
 
 
+if __name__ == '__main__':
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+
+    print(y.shape)
+
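Note that repst.__init__ now also reads configs['word_num'] for GumbelSoftmax, so the smoke test above would presumably need that key as well, e.g. configs['word_num'] = 2000 to match the value added to the YAML config.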
@@ -3,14 +3,23 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
-
-    y_soft = F.gumbel_softmax(logits, tau, hard)
-
-    if hard:
-        # build the hard mask
-        _, indices = y_soft.topk(k, dim=0)  # pick the top-k entries
-        y_hard = torch.zeros_like(logits)
-        y_hard.scatter_(0, indices, 1)
-        return torch.squeeze(y_hard, dim=-1)
-    return torch.squeeze(y_soft, dim=-1)
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
+
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
+
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
+
+        y_soft = F.gumbel_softmax(logits, tau, hard)
+
+        if hard:
+            # build the hard mask
+            _, indices = y_soft.topk(k, dim=0)  # pick the top-k entries
+            y_hard = torch.zeros_like(logits)
+            y_hard.scatter_(0, indices, 1)
+            return torch.squeeze(y_hard, dim=-1)
+        return torch.squeeze(y_soft, dim=-1)
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 
 class repst(nn.Module):
@@ -21,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
 
+        self.word_choice = GumbelSoftmax(configs['word_num'])
+
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
@@ -65,7 +67,7 @@ class repst(nn.Module):
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
 
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)