Add support for REPST in D8
parent b02c9c91d7
commit a3e43fc6df
.gitignore:
@@ -173,3 +173,5 @@ cython_debug/
 Result.xlsx
 .temp_repo/
 .exp/
+GPT-2/config.json
+GPT-2/pytorch_model.bin
.vscode/launch.json (VS Code debug configurations; path assumed):
@@ -11,6 +11,14 @@
       "program": "run.py",
       "console": "integratedTerminal",
       "args": "--config ./config/DDGCRN/PEMSD8.yaml"
+    },
+    {
+      "name": "REPST",
+      "type": "debugpy",
+      "request": "launch",
+      "program": "run.py",
+      "console": "integratedTerminal",
+      "args": "--config ./config/REPST/PEMSD8.yaml"
     }
   ]
 }
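For a plain (non-debugger) run, the same configuration can presumably be launched from the repository root, mirroring the "program" and "args" fields of the new launch entry:

python run.py --config ./config/REPST/PEMSD8.yaml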
existing training config (presumably config/DDGCRN/PEMSD8.yaml, the file referenced by the first launch entry):
@@ -38,7 +38,7 @@ train:
   batch_size: 64
   early_stop: true
   early_stop_patience: 15
-  epochs: 300
+  epochs: 1
   grad_norm: false
   loss_func: mae
   lr_decay: true
config/REPST/PEMSD8.yaml (new file):
@@ -0,0 +1,58 @@
basic:
  dataset: "PEMSD8"
  mode: "train"
  device: "cuda:0"
  model: "REPST"

data:
  add_day_in_week: true
  add_time_in_day: true
  column_wise: false
  days_per_week: 7
  default_graph: true
  horizon: 12
  lag: 12
  normalizer: std
  num_nodes: 170
  steps_per_day: 288
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
  sample: 1
  input_dim: 1
  batch_size: 64

model:
  pred_len: 12
  seq_len: 12
  patch_len: 6
  stride: 7
  dropout: 0.2
  gpt_layers: 9
  d_ff: 128
  gpt_path: ./GPT-2
  d_model: 64
  n_heads: 1

train:
  batch_size: 64
  early_stop: true
  early_stop_patience: 15
  epochs: 300
  grad_norm: false
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  lr_init: 0.003
  max_grad_norm: 5
  real_value: true
  seed: 12
  weight_decay: 0
  debug: false
  output_dim: 1
  log_step: 2000
  plot: false
  mae_thresh: None
  mape_thresh: 0.001
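For reference, a minimal sketch of reading this file from Python (an assumption about how run.py consumes the --config argument; the actual loading code is not part of this commit):

import yaml

with open("./config/REPST/PEMSD8.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["basic"]["model"])     # "REPST"
print(cfg["model"]["gpt_path"])  # "./GPT-2"
print(cfg["train"]["epochs"])    # 300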
model/REPST/normalizer.py (new file):
@@ -0,0 +1,105 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def gumbel_softmax(logits, tau=1, k=1000, hard=True):
    y_soft = F.gumbel_softmax(logits, tau, hard)

    if hard:
        # build a hard mask: keep the top-k entries along dim 0
        _, indices = y_soft.topk(k, dim=0)
        y_hard = torch.zeros_like(logits)
        y_hard.scatter_(0, indices, 1)
        return torch.squeeze(y_hard, dim=-1)
    return torch.squeeze(y_soft, dim=-1)
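A minimal usage sketch for gumbel_softmax (not part of the file), mirroring how repst.forward below calls it on the mapping-layer weights; the vocabulary size and embedding width are the GPT-2 values assumed elsewhere in this commit:

import torch

vocab_size, d_llm, k = 50257, 768, 1000
word_embeddings = torch.randn(vocab_size, d_llm)   # stand-in for the GPT-2 token embedding table
scores = torch.randn(vocab_size, 1)                # same shape as mapping_layer.weight.data.permute(1, 0)
mask = gumbel_softmax(scores, k=k)                 # (vocab_size,) 0/1 mask with exactly k ones
source_embeddings = word_embeddings[mask == 1]     # (k, d_llm) rows used as reprogramming sources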
class Normalize(nn.Module):
    def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super(Normalize, self).__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        self.subtract_last = subtract_last
        self.non_norm = non_norm
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str):
        if mode == 'norm':
            self._get_statistics(x)
            x = self._normalize(x)
        elif mode == 'denorm':
            x = self._denormalize(x)
        else:
            raise NotImplementedError
        return x

    def _init_params(self):
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        dim2reduce = tuple(range(1, x.ndim - 1))
        if self.subtract_last:
            self.last = x[:, -1, :].unsqueeze(1)
        else:
            self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach()
        self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach()

    def _normalize(self, x):
        if self.non_norm:
            return x
        if self.subtract_last:
            x = x - self.last
        else:
            x = x - self.mean
        x = x / self.stdev
        if self.affine:
            x = x * self.affine_weight
            x = x + self.affine_bias
        return x

    def _denormalize(self, x):
        if self.non_norm:
            return x
        if self.affine:
            x = x - self.affine_bias
            x = x / (self.affine_weight + self.eps * self.eps)
        x = x * self.stdev
        if self.subtract_last:
            x = x + self.last
        else:
            x = x + self.mean
        return x


class MultiLayerPerceptron(nn.Module):
    """Multi-Layer Perceptron with residual links."""

    def __init__(self, input_dim, hidden_dim) -> None:
        super().__init__()
        self.fc1 = nn.Conv2d(
            in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.fc2 = nn.Conv2d(
            in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True)
        self.act = nn.ReLU()
        self.drop = nn.Dropout(p=0.15)

    def forward(self, input_data: torch.Tensor) -> torch.Tensor:
        """
        input_data (torch.Tensor): input data with shape [B, D, N]
        """

        hidden = self.fc2(self.drop(self.act(self.fc1(input_data))))  # MLP
        hidden = hidden + input_data  # residual
        return hidden
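A quick round-trip sketch for Normalize above (not part of the file), assuming a (batch, time, channels) input, which is what the reduction over the middle dimensions implies:

import torch

norm = Normalize(num_features=1, affine=False)
x = torch.randn(8, 12, 1) * 5 + 3           # e.g. 8 samples, 12 steps, 1 channel
z = norm(x, mode='norm')                    # per-sample zero-mean / unit-std over the time dim
x_rec = norm(z, mode='denorm')              # reuses the statistics stored during 'norm'
print(torch.allclose(x, x_rec, atol=1e-5))  # True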
model/REPST/reprogramming.py (new file):
@@ -0,0 +1,99 @@
import torch
import torch.nn as nn
from torch import Tensor
from math import sqrt


class ReplicationPad1d(nn.Module):
    def __init__(self, padding) -> None:
        super(ReplicationPad1d, self).__init__()
        self.padding = padding

    def forward(self, input: Tensor) -> Tensor:
        replicate_padding = input[:, :, :, -1].unsqueeze(-1).repeat(1, 1, 1, self.padding[-1])
        output = torch.cat([input, replicate_padding], dim=-1)
        return output


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model):
        super(TokenEmbedding, self).__init__()
        padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        self.confusion_layer = nn.Linear(2, 1)
        # if air_quality
        # self.confusion_layer = nn.Linear(42, 1)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
        # 768,64,25
        x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num

        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)


class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, dropout):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        self.value_embedding = TokenEmbedding(patch_len, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        n_vars = x.shape[2]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_value_embed = self.value_embedding(x)

        return self.dropout(x_value_embed), n_vars
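A shape sketch for PatchEmbedding (not part of the file), using the PEMSD8 settings from the new config (d_model=64, patch_len=6, stride=7, seq_len=12, 170 nodes); the batch size of 8 is arbitrary:

import torch

pe = PatchEmbedding(d_model=64, patch_len=6, stride=7, dropout=0.2)
x = torch.randn(8, 170, 1, 12)  # (batch, node, feature, seq_len), as built in repst.forward
out, n_vars = pe(x)             # pad to length 19, unfold into 2 patches of length 6, then embed
print(out.shape, n_vars)        # torch.Size([8, 170, 64]) 1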
class ReprogrammingLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)

        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)

        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape

        scale = 1. / sqrt(E)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)

        return reprogramming_embedding
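And a matching shape sketch for ReprogrammingLayer (not part of the file), with d_model=64, n_heads=1, d_llm=768 as in the config; the 1000 source rows stand in for the word embeddings selected by the gumbel_softmax mask:

import torch

layer = ReprogrammingLayer(d_model=64, n_heads=1, d_llm=768)
target = torch.randn(8, 170, 64)     # (B, L, d_model): one embedding per node from PatchEmbedding
source = torch.randn(1000, 768)      # (S, d_llm): selected word-embedding rows
out = layer(target, source, source)  # cross-attention from patch embeddings to text prototypes
print(out.shape)                     # torch.Size([8, 170, 768])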
model/REPST/repst.py (new file):
@@ -0,0 +1,87 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.REPST.normalizer import Normalize, gumbel_softmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer


class repst(nn.Module):

    def __init__(self, configs):
        super(repst, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']  # output mapping dimension
        self.gpt_path = configs['gpt_path']

        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768

        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        # 64,6,7,0.2
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout)

        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]

        self.gpts.apply(self.reset_parameters)

        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1000
        self.n_vars = 5

        self.normalize_layers = Normalize(num_features=1, affine=False)
        self.mapping_layer = nn.Linear(self.vocab_size, 1)

        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)

        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )

        # freeze GPT-2 except the positional embeddings (wpe)
        for i, (name, param) in enumerate(self.gpts.named_parameters()):
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        x = x[..., :1]  # keep only the first feature channel
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, n_vars = self.patch_embedding(x_enc)
        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)  # note: this result is not used
        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1, 0))  # hard 0/1 mask over the vocabulary (k=1000 by default)
        source_embeddings = self.word_embeddings[masks == 1]

        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state  # truncated, mostly frozen GPT-2 stack

        dec_out = self.out_mlp(enc_out)
        outputs = dec_out.unsqueeze(dim=-1)
        outputs = outputs.repeat(1, 1, 1, n_vars)
        outputs = outputs.permute(0, 2, 1, 3)

        return outputs
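An end-to-end shape sketch (not part of the file). Assumptions: a GPT-2 checkpoint is present locally under ./GPT-2 (hence the new .gitignore entries), PyYAML reads the config, and the model section of the YAML plus a device entry is what reaches the constructor; none of that glue code is shown in this commit:

import torch
import yaml

with open("./config/REPST/PEMSD8.yaml") as f:
    cfg = yaml.safe_load(f)
model_cfg = {**cfg["model"], "device": "cpu"}  # override cuda:0 for a quick CPU smoke test

model = repst(model_cfg)
x = torch.randn(64, 12, 170, 1)  # (batch, lag, num_nodes, input_dim); only channel 0 is used
y = model(x)
print(y.shape)                   # expected torch.Size([64, 12, 170, 1]): (batch, horizon, num_nodes, output_dim)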
model selector module (file name not visible in this view):
@@ -22,6 +22,7 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel
 from model.ST_SSL.ST_SSL import STSSLModel
 from model.STGNRDE.Make_model import make_model as make_nrde_model
 from model.STAWnet.STAWnet import STAWnet
+from model.REPST.repst import repst as REPST


 def model_selector(config):

@@ -76,3 +77,5 @@ def model_selector(config):
         return make_nrde_model(model_config)
     case "STAWnet":
         return STAWnet(model_config)
+    case "REPST":
+        return REPST(model_config)
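With this branch in place, a config whose basic.model is set to "REPST" (such as config/REPST/PEMSD8.yaml above) should be routed by model_selector to the new repst module; how model_config is assembled from the YAML is unchanged by this commit.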