Clean up redundant repst code
parent 8e53d25ab1
commit 6657743afe
@@ -14,92 +14,3 @@ def gumbel_softmax(logits, tau=1, k=1000, hard=True):
         y_hard.scatter_(0, indices, 1)
         return torch.squeeze(y_hard, dim=-1)
     return torch.squeeze(y_soft, dim=-1)
-
-
-class Normalize(nn.Module):
-    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
-        """
-        :param num_features: the number of features or channels
-        :param eps: a value added for numerical stability
-        :param affine: if True, RevIN has learnable affine parameters
-        """
-        super(Normalize, self).__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.affine = affine
-        self.subtract_last = subtract_last
-        self.non_norm = non_norm
-        if self.affine:
-            self._init_params()
-
-    def forward(self, x, mode: str):
-        if mode == 'norm':
-            self._get_statistics(x)
-            x = self._normalize(x)
-        elif mode == 'denorm':
-            x = self._denormalize(x)
-        else:
-            raise NotImplementedError
-        return x
-
-    def _init_params(self):
-        # initialize RevIN params: (C,)
-        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
-        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))
-
-    def _get_statistics(self, x):
-        dim2reduce = tuple(range(1, x.ndim - 1))
-        if self.subtract_last:
-            self.last = x[:, -1, :].unsqueeze(1)
-        else:
-            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
-        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()
-
-    def _normalize(self, x):
-        if self.non_norm:
-            return x
-        if self.subtract_last:
-            x = x - self.last
-        else:
-            x = x - self.mean
-        x = x / self.stdev
-        if self.affine:
-            x = x * self.affine_weight
-            x = x + self.affine_bias
-        return x
-
-    def _denormalize(self, x):
-        if self.non_norm:
-            return x
-        if self.affine:
-            x = x - self.affine_bias
-            x = x / (self.affine_weight + self.eps * self.eps)
-        x = x * self.stdev
-        if self.subtract_last:
-            x = x + self.last
-        else:
-            x = x + self.mean
-        return x
-
-
-class MultiLayerPerceptron(nn.Module):
-    """Multi-Layer Perceptron with residual links."""
-
-    def __init__(self, input_dim, hidden_dim) -> None:
-        super().__init__()
-        self.fc1 = nn.Conv2d(
-            in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
-        self.fc2 = nn.Conv2d(
-            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
-        self.act = nn.ReLU()
-        self.drop = nn.Dropout(p=0.15)
-
-    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
-        """
-        input_data (torch.Tensor): input data with shape [B, D, N]
-        """
-
-        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
-        hidden = hidden + input_data  # residual
-        return hidden
-
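For reference, the removed Normalize layer follows the usual RevIN-style pattern: normalize on the way into the model, denormalize on the way out. A minimal round-trip sketch, assuming the class definition shown above is still available and an arbitrary [batch, time, channels] input:

import torch

# Hypothetical usage of the removed Normalize layer (definition in the hunk above).
norm = Normalize(num_features=1, affine=False)
x = torch.randn(8, 24, 1)            # [batch, time, channels]
x_n = norm(x, mode='norm')           # subtract per-sample mean, divide by stdev
# ... a forecasting model would run on x_n here ...
x_back = norm(x_n, mode='denorm')    # restore the original scale
assert torch.allclose(x, x_back, atol=1e-5)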
@@ -22,9 +22,6 @@ class TokenEmbedding(nn.Module):
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
 
         self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
-        # if air_quality
-        # self.confusion_layer = nn.Linear(42, 1)
-
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -59,7 +56,7 @@ class PatchEmbedding(nn.Module):
         return self.dropout(x_value_embed), n_vars
 
 class ReprogrammingLayer(nn.Module):
-    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
+    def __init__(self, d_model, n_heads, d_keys, d_llm, attention_dropout=0.1):
         super(ReprogrammingLayer, self).__init__()
 
         d_keys = d_keys or (d_model // n_heads)
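With the defaults removed, callers must now pass d_keys and d_llm explicitly (as repst already does further down). A construction sketch with illustrative values; d_llm=768 is only an assumption matching GPT-2 small:

# Illustrative only: d_keys=None still works because of the fallback
# `d_keys = d_keys or (d_model // n_heads)` inside __init__.
layer = ReprogrammingLayer(d_model=64, n_heads=8, d_keys=None, d_llm=768)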
@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
 from einops import rearrange
-from model.REPST.normalizer import Normalize, gumbel_softmax
+from model.REPST.normalizer import gumbel_softmax
 from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer
 
 class repst(nn.Module):
@@ -17,7 +17,7 @@ class repst(nn.Module):
         self.stride = configs['stride']
         self.dropout = configs['dropout']
         self.gpt_layers = configs['gpt_layers']
-        self.d_ff = configs['d_ff']  # output mapping dimension
+        self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
 
         self.d_model = configs['d_model']
@@ -28,23 +28,17 @@ class repst(nn.Module):
         self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
         self.head_nf = self.d_ff * self.patch_nums
 
-        # 64,6,7,0.2
+        # word embedding
         self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
 
-
+        # GPT2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
         self.gpts.h = self.gpts.h[:self.gpt_layers]
 
         self.gpts.apply(self.reset_parameters)
 
         self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
         self.vocab_size = self.word_embeddings.shape[0]
-        self.num_tokens = 1000
-        self.n_vars = 5
-
-
-        self.normalize_layers = Normalize(num_features=1, affine=False)
         self.mapping_layer = nn.Linear(self.vocab_size, 1)
 
         self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
 
         self.out_mlp = nn.Sequential(
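The GPT-2 setup kept above (load the model, truncate to the first gpt_layers blocks, grab the token-embedding matrix) can be reproduced standalone. A sketch with illustrative values; the 'gpt2' path and the 3-layer cutoff are assumptions, not values from this repository's config:

from transformers.models.gpt2.modeling_gpt2 import GPT2Model

gpts = GPT2Model.from_pretrained('gpt2', output_attentions=True, output_hidden_states=True)
gpts.h = gpts.h[:3]                                    # keep only the first 3 transformer blocks
word_embeddings = gpts.get_input_embeddings().weight   # [vocab_size, 768] for GPT-2 small
vocab_size = word_embeddings.shape[0]                  # 50257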
@@ -65,8 +59,6 @@ class repst(nn.Module):
             if hasattr(module, 'bias') and module.bias is not None:
                 torch.nn.init.zeros_(module.bias)
 
-
-
     def forward(self, x):
         x = x[..., :1]
         x_enc = rearrange(x, 'b t n c -> b n c t')
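The first two lines of forward keep only the first input channel and move the time axis last before patching. A shape sketch with made-up sizes:

import torch
from einops import rearrange

x = torch.randn(8, 24, 207, 3)              # [B, T, N, C], sizes are illustrative
x = x[..., :1]                              # keep the first channel -> [8, 24, 207, 1]
x_enc = rearrange(x, 'b t n c -> b n c t')
print(x_enc.shape)                          # torch.Size([8, 207, 1, 24])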