import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange

from model.REPST.normalizer import GumbelSoftmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer

class repst(nn.Module):
    """Reprogramming-style forecaster: time-series patches are embedded, aligned with a
    subset of GPT-2 word embeddings, passed through a partially frozen GPT-2 backbone,
    and projected to the forecast horizon."""

    def __init__(self, configs):
        super(repst, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']

        # Gumbel-Softmax selector over the GPT-2 vocabulary
        self.word_choice = GumbelSoftmax(configs['word_num'])

        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768  # hidden size of the GPT-2 base model

        # Number of patches produced from each input series
        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
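        # Illustration with the demo configuration below (seq_len=24, patch_len=6, stride=7):
        # patch_nums = int((24 - 6) / 7 + 2) = int(4.57...) = 4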
        self.head_nf = self.d_ff * self.patch_nums

        # Patch embedding: splits each series into patches and projects them to d_model
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)

        # GPT-2 initialization: load the backbone, keep only the first gpt_layers blocks,
        # and re-initialize their parameters (see reset_parameters below)
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        self.gpts.apply(self.reset_parameters)

        # GPT-2 token embedding matrix (vocab_size x d_llm) and a linear layer that scores
        # every vocabulary entry
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        # Reprogramming layer: aligns patch embeddings with the GPT-2 embedding space
        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)

        # Prediction head: GPT-2 features -> forecast of length pred_len
        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )

        # Train only the positional embeddings ('wpe'); freeze every other GPT-2 parameter
        for name, param in self.gpts.named_parameters():
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False
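
        # Optional sanity check (illustrative): only the positional embedding table should
        # remain trainable at this point, e.g.:
        # trainable = [n for n, p in self.gpts.named_parameters() if p.requires_grad]
        # assert all('wpe' in n for n in trainable)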

    def reset_parameters(self, module):
        # Re-initialize the weight and bias of a sub-module; applied to every GPT-2 module above
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        # x: (batch, seq_len, num_nodes, channels); only the first channel is used
        x = x[..., :1]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, n_vars = self.patch_embedding(x_enc)

        # Linear mapping over the vocabulary dimension (the result is not used further)
        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
        # Turn the scores in mapping_layer.weight into a Gumbel-Softmax mask and keep only
        # the selected GPT-2 word embeddings
        masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]
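
        # Note: the GumbelSoftmax module (external, in model.REPST.normalizer) is assumed to
        # return a {0, 1} selection over the vocabulary; the selected embedding rows act as
        # the text prototypes that the reprogramming layer attends to below.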

        # Reprogram the patch embeddings with the selected word embeddings, then run GPT-2
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state

        # Prediction head, then move the forecast horizon to the second dimension
        dec_out = self.out_mlp(enc_out)
        outputs = dec_out.unsqueeze(dim=-1)
        outputs = outputs.repeat(1, 1, 1, n_vars)
        outputs = outputs.permute(0, 2, 1, 3)

        return outputs


if __name__ == '__main__':
    configs = {
        'device': 'cuda:0',
        'pred_len': 24,
        'seq_len': 24,
        'patch_len': 6,
        'stride': 7,
        'dropout': 0.2,
        'gpt_layers': 9,
        'd_ff': 128,
        'gpt_path': './GPT-2',
        'd_model': 64,
        'n_heads': 1,
        'input_dim': 1,
        'word_num': 1000,  # assumed value; required by the GumbelSoftmax selector
    }
    model = repst(configs).to(configs['device'])
    x = torch.randn(16, 24, 325, 1, device=configs['device'])  # (batch, seq_len, num_nodes, channels)
    y = model(x)

    print(y.shape)
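
    # Note: this demo expects a local GPT-2 (base, hidden size 768) checkpoint at './GPT-2',
    # e.g. one produced by GPT2Model.from_pretrained('gpt2').save_pretrained('./GPT-2'),
    # plus the companion model.REPST.normalizer and model.REPST.reprogramming modules.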