diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml
index 0198c62..6fcf12b 100755
--- a/config/REPST/PEMS-BAY.yaml
+++ b/config/REPST/PEMS-BAY.yaml
@@ -34,6 +34,7 @@ model:
   d_model: 64
   n_heads: 1
   input_dim: 1
+  word_num: 2000
 
 train:
   batch_size: 16
diff --git a/model/AEPSA/normalizer.py b/model/AEPSA/normalizer.py
index ad6630d..fb7e182 100644
--- a/model/AEPSA/normalizer.py
+++ b/model/AEPSA/normalizer.py
@@ -3,14 +3,23 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
 
-    y_soft = F.gumbel_softmax(logits, tau, hard)
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
 
-    if hard:
-        # build a hard top-k mask
-        _, indices = y_soft.topk(k, dim=0)  # select the top-k entries
-        y_hard = torch.zeros_like(logits)
-        y_hard.scatter_(0, indices, 1)
-        return torch.squeeze(y_hard, dim=-1)
-    return torch.squeeze(y_soft, dim=-1)
\ No newline at end of file
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
+
+        y_soft = F.gumbel_softmax(logits, tau, hard)
+
+        if hard:
+            # build a hard top-k mask
+            _, indices = y_soft.topk(k, dim=0)  # select the top-k entries
+            y_hard = torch.zeros_like(logits)
+            y_hard.scatter_(0, indices, 1)
+            return torch.squeeze(y_hard, dim=-1)
+        return torch.squeeze(y_soft, dim=-1)
\ No newline at end of file
diff --git a/model/AEPSA/repst.py b/model/AEPSA/repst.py
index 9a6af69..09ceb29 100644
--- a/model/AEPSA/repst.py
+++ b/model/AEPSA/repst.py
@@ -1,8 +1,9 @@
+from tkinter import Y
 import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 
 class repst(nn.Module):
@@ -20,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
 
+        self.word_choice = GumbelSoftmax(configs['word_num'])
+
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
@@ -64,7 +67,7 @@
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
 
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
@@ -77,4 +80,26 @@
 
         return outputs
 
+if __name__ == '__main__':
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1,
+        'word_num': 2000
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+
+    print(y.shape)
+
 
diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py
index ad6630d..fb7e182 100644
--- a/model/REPST/normalizer.py
+++ b/model/REPST/normalizer.py
@@ -3,14 +3,23 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
-def gumbel_softmax(logits, tau=1, k=1000, hard=True):
+class GumbelSoftmax(nn.Module):
+    def __init__(self, k=1000, hard=True):
+        super(GumbelSoftmax, self).__init__()
+        self.k = k
+        self.hard = hard
 
-    y_soft = F.gumbel_softmax(logits, tau, hard)
+    def forward(self, logits):
+        return self.gumbel_softmax(logits, 1, self.k, self.hard)
 
-    if hard:
-        # build a hard top-k mask
-        _, indices = y_soft.topk(k, dim=0)  # select the top-k entries
-        y_hard = torch.zeros_like(logits)
-        y_hard.scatter_(0, indices, 1)
-        return torch.squeeze(y_hard, dim=-1)
-    return torch.squeeze(y_soft, dim=-1)
\ No newline at end of file
+    def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
+
+        y_soft = F.gumbel_softmax(logits, tau, hard)
+
+        if hard:
+            # build a hard top-k mask
+            _, indices = y_soft.topk(k, dim=0)  # select the top-k entries
+            y_hard = torch.zeros_like(logits)
+            y_hard.scatter_(0, indices, 1)
+            return torch.squeeze(y_hard, dim=-1)
+        return torch.squeeze(y_soft, dim=-1)
\ No newline at end of file
diff --git a/model/REPST/repst.py b/model/REPST/repst.py
index 79502c2..09ceb29 100644
--- a/model/REPST/repst.py
+++ b/model/REPST/repst.py
@@ -3,7 +3,7 @@ import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import gumbel_softmax
+from model.REPST.normalizer import GumbelSoftmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 
 class repst(nn.Module):
@@ -21,6 +21,8 @@ class repst(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
 
+        self.word_choice = GumbelSoftmax(configs['word_num'])
+
         self.d_model = configs['d_model']
         self.n_heads = configs['n_heads']
         self.d_keys = None
@@ -65,7 +67,7 @@
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
-        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0))
+        masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0))
         source_embeddings = self.word_embeddings[masks==1]
 
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
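
For context, the snippet below is a minimal standalone sketch of how the new GumbelSoftmax module is exercised inside repst.forward: the transposed mapping-layer weights act as one logit per vocabulary entry, the module returns a hard 0/1 mask containing word_num ones, and boolean indexing keeps that many rows of the word-embedding table as the source embeddings fed to the reprogramming layer. It is meant to be run from the repository root; the vocabulary size (50257) and embedding width (768, GPT-2 base), the random embedding table, and the single-output mapping layer are illustrative assumptions rather than values read from the repository. Only GumbelSoftmax itself comes from the diff above.

# Illustrative sketch; sizes and the mapping layer are stand-ins, not repo values.
import torch
import torch.nn as nn

from model.REPST.normalizer import GumbelSoftmax

vocab_size, d_llm, word_num = 50257, 768, 2000    # word_num as in PEMS-BAY.yaml
word_embeddings = torch.randn(vocab_size, d_llm)  # stand-in for GPT-2's token-embedding table

# Stand-in for self.mapping_layer: transposing its weight gives one logit per vocabulary entry.
mapping_layer = nn.Linear(vocab_size, 1)

word_choice = GumbelSoftmax(k=word_num, hard=True)

# (vocab_size, 1) logits -> hard {0,1} mask with word_num ones, squeezed to shape (vocab_size,)
masks = word_choice(mapping_layer.weight.data.permute(1, 0))

# Boolean indexing keeps word_num rows of the embedding table as reprogramming prototypes.
source_embeddings = word_embeddings[masks == 1]
print(source_embeddings.shape)  # torch.Size([2000, 768])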