import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange

from model.REPST.normalizer import GumbelSoftmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer


class repst(nn.Module):
    """Forecaster that reprograms patch tokens into a truncated GPT-2 backbone.

    Pipeline (see ``forward``): patch-embed the input series, cross-attend the
    patch tokens to a subset of GPT-2 token embeddings chosen by
    ``GumbelSoftmax`` (``ReprogrammingLayer``), run the result through the first
    ``gpt_layers`` GPT-2 blocks, and decode each output token with a small MLP.

    Args:
        configs: dict with keys ``device``, ``pred_len``, ``seq_len``,
            ``patch_len``, ``input_dim``, ``stride``, ``dropout``,
            ``gpt_layers``, ``d_ff``, ``gpt_path``, ``word_num``, ``d_model``,
            ``n_heads`` and, optionally, ``output_dim`` (default 1).
    """

    def __init__(self, configs):
        super().__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.output_dim = configs.get('output_dim', 1)

        self.word_choice = GumbelSoftmax(configs['word_num'])

        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None

        # Number of patches over seq_len. The "+2" (instead of +1) presumably
        # accounts for end padding done inside PatchEmbedding -- TODO confirm.
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        # Not used within this class; kept as a public attribute.
        self.head_nf = self.d_ff * self.patch_nums

        # Patch embedding of the raw input series.
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout,
                                              self.patch_nums, self.input_dim, self.output_dim)

        # GPT-2 backbone, truncated to its first ``gpt_layers`` transformer blocks.
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True,
                                              output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        # NOTE(review): this re-initialises every GPT-2 module that has a weight
        # (token/position embeddings, attention/MLP projections, LayerNorm scales),
        # discarding the pretrained values loaded just above -- so the "word
        # embeddings" used for reprogramming are random. Confirm this is intended.
        self.gpts.apply(self.reset_parameters)

        # Backbone hidden width (768 for base GPT-2). Read from the loaded config
        # instead of hard-coding it so larger GPT-2 checkpoints produce
        # reprogramming/decoder widths that match the backbone.
        self.d_llm = self.gpts.config.n_embd

        self.vocab_size = self.word_embeddings.shape[0]
        # Its (1, vocab_size) weight is what ``word_choice`` scores in forward().
        self.mapping_layer = nn.Linear(self.vocab_size, 1)

        self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads,
                                                      self.d_keys, self.d_llm)

        # Maps each backbone output token to pred_len * output_dim values.
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len * self.output_dim),
        )

        # Freeze the backbone except its positional embeddings (``wpe``).
        for name, param in self.gpts.named_parameters():
            param.requires_grad = 'wpe' in name

    @property
    def word_embeddings(self):
        """GPT-2 token-embedding matrix, shape (vocab_size, d_llm).

        Read from the backbone on every access rather than cached at
        construction time. The matrix therefore always lives wherever the model
        lives (``model.to(...)``, DataParallel replicas), and the module does not
        hold a device-copied, autograd-tracked tensor as a plain attribute (such
        a tensor breaks ``copy.deepcopy(model)``).

        NOTE: the earlier implementation registered this matrix a second time,
        under the ``word_embeddings`` key, whenever ``configs['device']`` matched
        the load device. Checkpoints saved that way carry that extra key. Drop it,
        or load with ``strict=False``.
        """
        return self.gpts.get_input_embeddings().weight

    def reset_parameters(self, module):
        """Re-initialise ``module``'s own weight ~ N(0, 0.02) and zero its bias.

        Intended for ``nn.Module.apply``. Modules whose ``weight``/``bias`` is
        missing or ``None`` are left untouched.
        """
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        """Forecast ``pred_len`` future steps.

        Args:
            x: 4-D tensor laid out as (batch, time, nodes, channels); only the
                first ``input_dim`` channels are used.

        Returns:
            Tensor of shape (B, pred_len, N, output_dim), where B and N are the
            two leading dims of the decoder output. They are expected to be
            batch and nodes -- TODO confirm against ReprogrammingLayer.
        """
        x = x[..., :self.input_dim]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, _n_vars = self.patch_embedding(x_enc)

        # Choose vocabulary rows to serve as reprogramming prototypes.
        # Assumes word_choice returns a 0/1 mask over vocabulary rows -- TODO confirm.
        # NOTE(review): the mapping weights are detached (as with the former
        # ``.data``), so mapping_layer receives no gradient here; confirm intended.
        masks = self.word_choice(self.mapping_layer.weight.detach().permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]

        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state

        dec_out = self.out_mlp(enc_out)  # (B, N, pred_len * output_dim)
        B, N, _ = dec_out.shape
        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
        return outputs.permute(0, 2, 1, 3)  # (B, pred_len, N, output_dim)