From a3e43fc6df807786f328dac979d3b3a0764bdb08 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 16:27:32 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=94=AF=E6=8C=81REPST=20In?= =?UTF-8?q?=20D8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 4 +- .vscode/launch.json | 8 +++ config/PDG2SEQ/PEMSD8.yaml | 2 +- config/REPST/PEMSD8.yaml | 58 +++++++++++++++++++ model/REPST/normalizer.py | 105 +++++++++++++++++++++++++++++++++++ model/REPST/reprogramming.py | 99 +++++++++++++++++++++++++++++++++ model/REPST/repst.py | 87 +++++++++++++++++++++++++++++ model/model_selector.py | 3 + 8 files changed, 364 insertions(+), 2 deletions(-) create mode 100755 config/REPST/PEMSD8.yaml create mode 100644 model/REPST/normalizer.py create mode 100644 model/REPST/reprogramming.py create mode 100644 model/REPST/repst.py diff --git a/.gitignore b/.gitignore index 18182c9..d7124b8 100755 --- a/.gitignore +++ b/.gitignore @@ -172,4 +172,6 @@ cython_debug/ .DS_Store Result.xlsx .temp_repo/ -.exp/ \ No newline at end of file +.exp/ +GPT-2/config.json +GPT-2/pytorch_model.bin diff --git a/.vscode/launch.json b/.vscode/launch.json index f383630..3f71813 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -11,6 +11,14 @@ "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/DDGCRN/PEMSD8.yaml" + }, + { + "name": "REPST", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/PEMSD8.yaml" } ] } \ No newline at end of file diff --git a/config/PDG2SEQ/PEMSD8.yaml b/config/PDG2SEQ/PEMSD8.yaml index 9c6693a..c1f7f6e 100755 --- a/config/PDG2SEQ/PEMSD8.yaml +++ b/config/PDG2SEQ/PEMSD8.yaml @@ -38,7 +38,7 @@ train: batch_size: 64 early_stop: true early_stop_patience: 15 - epochs: 300 + epochs: 1 grad_norm: false loss_func: mae lr_decay: true diff --git a/config/REPST/PEMSD8.yaml b/config/REPST/PEMSD8.yaml new file mode 100755 index 0000000..f4b970b --- /dev/null +++ b/config/REPST/PEMSD8.yaml @@ -0,0 +1,58 @@ +basic: + dataset: "PEMSD8" + mode : "train" + device : "cuda:0" + model: "REPST" + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 170 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 64 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + +train: + batch_size: 64 + early_stop: true + early_stop_patience: 15 + epochs: 300 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + seed: 12 + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 2000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py new file mode 100644 index 0000000..f437265 --- /dev/null +++ b/model/REPST/normalizer.py @@ -0,0 +1,105 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +def gumbel_softmax(logits, tau=1, k=1000, hard=True): + + y_soft = F.gumbel_softmax(logits, tau, hard) + + if hard: + # 生成硬掩码 + _, indices = y_soft.topk(k, dim=0) # 选择Top-K + y_hard = torch.zeros_like(logits) + y_hard.scatter_(0, indices, 1) + return torch.squeeze(y_hard, dim=-1) + return torch.squeeze(y_soft, dim=-1) + + + +class Normalize(nn.Module): + def __init__(self, num_features: int, eps=1e-5, affine=False, subtract_last=False, non_norm=False): + """ + :param num_features: the number of features or channels + :param eps: a value added for numerical stability + :param affine: if True, RevIN has learnable affine parameters + """ + super(Normalize, self).__init__() + self.num_features = num_features + self.eps = eps + self.affine = affine + self.subtract_last = subtract_last + self.non_norm = non_norm + if self.affine: + self._init_params() + + def forward(self, x, mode: str): + if mode == 'norm': + self._get_statistics(x) + x = self._normalize(x) + elif mode == 'denorm': + x = self._denormalize(x) + else: + raise NotImplementedError + return x + + def _init_params(self): + # initialize RevIN params: (C,) + self.affine_weight = nn.Parameter(torch.ones(self.num_features)) + self.affine_bias = nn.Parameter(torch.zeros(self.num_features)) + + def _get_statistics(self, x): + dim2reduce = tuple(range(1, x.ndim - 1)) + if self.subtract_last: + self.last = x[:, -1, :].unsqueeze(1) + else: + self.mean = torch.mean(x, dim=dim2reduce, keepdim=True).detach() + self.stdev = torch.sqrt(torch.var(x, dim=dim2reduce, keepdim=True, unbiased=False) + self.eps).detach() + + def _normalize(self, x): + if self.non_norm: + return x + if self.subtract_last: + x = x - self.last + else: + x = x - self.mean + x = x / self.stdev + if self.affine: + x = x * self.affine_weight + x = x + self.affine_bias + return x + + def _denormalize(self, x): + if self.non_norm: + return x + if self.affine: + x = x - self.affine_bias + x = x / (self.affine_weight + self.eps * self.eps) + x = x * self.stdev + if self.subtract_last: + x = x + self.last + else: + x = x + self.mean + return x + + +class MultiLayerPerceptron(nn.Module): + """Multi-Layer Perceptron with residual links.""" + + def __init__(self, input_dim, hidden_dim) -> None: + super().__init__() + self.fc1 = nn.Conv2d( + in_channels=input_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) + self.fc2 = nn.Conv2d( + in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=(1, 1), bias=True) + self.act = nn.ReLU() + self.drop = nn.Dropout(p=0.15) + + def forward(self, input_data: torch.Tensor) -> torch.Tensor: + """ + input_data (torch.Tensor): input data with shape [B, D, N] + """ + + hidden = self.fc2(self.drop(self.act(self.fc1(input_data)))) # MLP + hidden = hidden + input_data # residual + return hidden \ No newline at end of file diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py new file mode 100644 index 0000000..f5e9663 --- /dev/null +++ b/model/REPST/reprogramming.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +from torch import Tensor +from math import sqrt + + +class ReplicationPad1d(nn.Module): + def __init__(self, padding) -> None: + super(ReplicationPad1d, self).__init__() + self.padding = padding + + def forward(self, input: Tensor) -> Tensor: + replicate_padding = input[:, :, :, -1].unsqueeze(-1).repeat(1, 1, 1, self.padding[-1]) + output = torch.cat([input, replicate_padding], dim=-1) + return output + +class TokenEmbedding(nn.Module): + def __init__(self, c_in, d_model): + super(TokenEmbedding, self).__init__() + padding = 1 + self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, + kernel_size=3, padding=padding, padding_mode='circular', bias=False) + self.confusion_layer = nn.Linear(2, 1) + # if air_quality + # self.confusion_layer = nn.Linear(42, 1) + + + for m in self.modules(): + if isinstance(m, nn.Conv1d): + nn.init.kaiming_normal_( + m.weight, mode='fan_in', nonlinearity='leaky_relu') + + def forward(self, x): + b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len + # 768,64,25 + x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num + + x = self.confusion_layer(x) + return x.reshape(b, n, -1) + + +class PatchEmbedding(nn.Module): + def __init__(self, d_model, patch_len, stride, dropout): + super(PatchEmbedding, self).__init__() + # Patching + self.patch_len = patch_len + self.stride = stride + self.padding_patch_layer = ReplicationPad1d((0, stride)) + self.value_embedding = TokenEmbedding(patch_len, d_model) + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + + n_vars = x.shape[2] + x = self.padding_patch_layer(x) + x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) + x_value_embed = self.value_embedding(x) + + return self.dropout(x_value_embed), n_vars + +class ReprogrammingLayer(nn.Module): + def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1): + super(ReprogrammingLayer, self).__init__() + + d_keys = d_keys or (d_model // n_heads) + + self.query_projection = nn.Linear(d_model, d_keys * n_heads) + self.key_projection = nn.Linear(d_llm, d_keys * n_heads) + self.value_projection = nn.Linear(d_llm, d_keys * n_heads) + self.out_projection = nn.Linear(d_keys * n_heads, d_llm) + self.n_heads = n_heads + self.dropout = nn.Dropout(attention_dropout) + + def forward(self, target_embedding, source_embedding, value_embedding): + B, L, _ = target_embedding.shape + S, _ = source_embedding.shape + H = self.n_heads + + target_embedding = self.query_projection(target_embedding).view(B, L, H, -1) + source_embedding = self.key_projection(source_embedding).view(S, H, -1) + value_embedding = self.value_projection(value_embedding).view(S, H, -1) + + out = self.reprogramming(target_embedding, source_embedding, value_embedding) + out = out.reshape(B, L, -1) + + return self.out_projection(out) + + def reprogramming(self, target_embedding, source_embedding, value_embedding): + B, L, H, E = target_embedding.shape + + scale = 1. / sqrt(E) + + scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding) + + A = self.dropout(torch.softmax(scale * scores, dim=-1)) + reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding) + + return reprogramming_embedding + \ No newline at end of file diff --git a/model/REPST/repst.py b/model/REPST/repst.py new file mode 100644 index 0000000..66468ea --- /dev/null +++ b/model/REPST/repst.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn +from transformers.models.gpt2.modeling_gpt2 import GPT2Model +from einops import rearrange +from model.REPST.normalizer import Normalize, gumbel_softmax +from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer + +class repst(nn.Module): + + def __init__(self, configs): + super(repst, self).__init__() + self.device = configs['device'] + self.pred_len = configs['pred_len'] + self.seq_len = configs['seq_len'] + self.patch_len = configs['patch_len'] + self.stride = configs['stride'] + self.dropout = configs['dropout'] + self.gpt_layers = configs['gpt_layers'] + self.d_ff = configs['d_ff'] # output mapping dimension + self.gpt_path = configs['gpt_path'] + + self.d_model = configs['d_model'] + self.n_heads = configs['n_heads'] + self.d_keys = None + self.d_llm = 768 + + self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2) + self.head_nf = self.d_ff * self.patch_nums + + # 64,6,7,0.2 + self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout) + + self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) + self.gpts.h = self.gpts.h[:self.gpt_layers] + + self.gpts.apply(self.reset_parameters) + + self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) + self.vocab_size = self.word_embeddings.shape[0] + self.num_tokens = 1000 + self.n_vars = 5 + + + self.normalize_layers = Normalize(num_features=1, affine=False) + self.mapping_layer = nn.Linear(self.vocab_size, 1) + + self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm) + + self.out_mlp = nn.Sequential( + nn.Linear(self.d_llm, 128), + nn.ReLU(), + nn.Linear(128, self.pred_len) + ) + + for i, (name, param) in enumerate(self.gpts.named_parameters()): + if 'wpe' in name: + param.requires_grad = True + else: + param.requires_grad = False + + def reset_parameters(self, module): + if hasattr(module, 'weight') and module.weight is not None: + torch.nn.init.normal_(module.weight, mean=0.0, std=0.02) + if hasattr(module, 'bias') and module.bias is not None: + torch.nn.init.zeros_(module.bias) + + + + def forward(self, x): + x = x[..., :1] + x_enc = rearrange(x, 'b t n c -> b n c t') + enc_out, n_vars = self.patch_embedding(x_enc) + self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) + masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1,0)) + source_embeddings = self.word_embeddings[masks==1] + + enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) + enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state + + dec_out = self.out_mlp(enc_out) + outputs = dec_out.unsqueeze(dim=-1) + outputs = outputs.repeat(1, 1, 1, n_vars) + outputs = outputs.permute(0,2,1,3) + + return outputs + + diff --git a/model/model_selector.py b/model/model_selector.py index 0fc01da..80169d9 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -22,6 +22,7 @@ from model.MegaCRN.MegaCRNModel import MegaCRNModel from model.ST_SSL.ST_SSL import STSSLModel from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STAWnet.STAWnet import STAWnet +from model.REPST.repst import repst as REPST def model_selector(config): @@ -76,3 +77,5 @@ def model_selector(config): return make_nrde_model(model_config) case "STAWnet": return STAWnet(model_config) + case "REPST": + return REPST(model_config)