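# Patch-based time-series forecaster built on a pretrained GPT-2 backbone that is
# frozen except for its positional embeddings ('wpe'): the input series is split
# into overlapping patches, embedded, passed through GPT-2, and projected to the
# prediction horizon.
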
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange


class fpt(nn.Module):
    def __init__(self, configs):
        super(fpt, self).__init__()
        self.patch_len = configs['patch_len']
        self.stride = configs['stride']
        self.input_dim = configs['input_dim']
        self.seq_len = configs['seq_len']
        self.pred_len = configs['pred_len']
        self.gpt_layers = configs['gpt_layers']  # number of GPT-2 layers to keep
        self.d_model = configs['d_model']
        self.gpt_path = configs['gpt_path']

        # Number of patches after right-padding the series by `stride`, e.g. with
        # seq_len=336, patch_len=16, stride=8 (illustrative values):
        # (336 - 16) / 8 + 2 = 42 patches.
        self.patch_num = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.padding_patch_layer = nn.ReplicationPad1d((0, self.stride))

        # Load the pretrained GPT-2 and keep only the first `gpt_layers` blocks.
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]

        # Freeze the backbone; only the positional embeddings ('wpe') stay trainable.
        for name, param in self.gpts.named_parameters():
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Map each patch to the GPT-2 hidden size, and the flattened hidden states
        # of all patches to the prediction horizon.
        self.in_layer = nn.Linear(self.patch_len, self.d_model)
        self.out_layer = nn.Linear(self.d_model * self.patch_num, self.pred_len)

    def forward(self, x):
        B, L, M = x.shape                    # (batch, seq_len, n_variables)
        x = x[..., :self.input_dim]          # keep only the first input_dim variables
        M = x.shape[-1]                      # number of variables actually used
        x = rearrange(x, 'b l m -> b m l')

        # Right-pad, then slice into overlapping patches: (B, M, patch_num, patch_len).
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x = rearrange(x, 'b m n p -> (b m) n p')

        # Embed each patch, run the (mostly frozen) GPT-2 backbone, then project
        # the flattened hidden states of all patches to the prediction horizon.
        outputs = self.in_layer(x)
        outputs = self.gpts(inputs_embeds=outputs).last_hidden_state
        outputs = self.out_layer(outputs.reshape(B * M, -1))
        outputs = rearrange(outputs, '(b m) l -> b l m', b=B)
        return outputs
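

# Minimal usage sketch, assuming illustrative config values: the 'gpt2' checkpoint
# name, batch size, and all numbers below are assumptions, not fixed by this module.
# d_model must match the hidden size of the chosen GPT-2 checkpoint (768 for 'gpt2').
if __name__ == '__main__':
    import torch

    configs = {
        'patch_len': 16,
        'stride': 8,
        'input_dim': 7,
        'seq_len': 336,
        'pred_len': 96,
        'gpt_layers': 6,
        'd_model': 768,      # GPT-2 base hidden size
        'gpt_path': 'gpt2',  # assumed Hugging Face checkpoint; a local path also works
    }
    model = fpt(configs)
    x = torch.randn(4, configs['seq_len'], configs['input_dim'])  # (batch, seq_len, variables)
    y = model(x)
    print(y.shape)  # expected: torch.Size([4, 96, 7])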