import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange

from model.REPST.normalizer import GumbelSoftmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer


class repst(nn.Module):
    def __init__(self, configs):
        super(repst, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.word_choice = GumbelSoftmax(configs['word_num'])
        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768  # GPT-2 hidden size
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        # Patch embedding
        self.patch_embedding = PatchEmbedding(
            self.d_model, self.patch_len, self.stride,
            self.dropout, self.patch_nums, self.input_dim)

        # GPT-2 backbone: load pretrained weights, keep the first gpt_layers blocks,
        # then re-initialize the retained modules.
        self.gpts = GPT2Model.from_pretrained(
            self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        self.gpts.apply(self.reset_parameters)

        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]

        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        self.reprogramming_layer = ReprogrammingLayer(
            self.d_model, self.n_heads, self.d_keys, self.d_llm)

        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )

        # Freeze all GPT-2 parameters except the positional embeddings (wpe).
        for name, param in self.gpts.named_parameters():
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        # x: (batch, time, nodes, channels); keep only the first channel.
        x = x[..., :1]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Select a subset of the GPT-2 vocabulary as source embeddings via a
        # Gumbel-Softmax mask over the mapping-layer weights.
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]

        # Reprogram patch tokens into the LLM embedding space, then run the frozen GPT-2.
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state

        dec_out = self.out_mlp(enc_out)
        outputs = dec_out.unsqueeze(dim=-1)
        outputs = outputs.repeat(1, 1, 1, n_vars)
        outputs = outputs.permute(0, 2, 1, 3)
        return outputs


if __name__ == '__main__':
    configs = {
        'device': 'cuda:0',
        'pred_len': 24,
        'seq_len': 24,
        'patch_len': 6,
        'stride': 7,
        'dropout': 0.2,
        'gpt_layers': 9,
        'd_ff': 128,
        'gpt_path': './GPT-2',
        'd_model': 64,
        'n_heads': 1,
        'input_dim': 1,
        'word_num': 1000  # not in the original snippet; assumed size for the GumbelSoftmax word selection
    }
    model = repst(configs)
    x = torch.randn(16, 24, 325, 1)
    y = model(x)
    print(y.shape)