From 15f083c3d941a828acc17c44a85c07d2de275d9a Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Wed, 12 Nov 2025 16:40:13 +0800
Subject: [PATCH] Fix permute bug in TokenEmbedding

---
 config/REPST/PEMS-BAY.yaml   |  2 +-
 model/AEPSA/reprogramming.py |  3 ++-
 model/REPST/reprogramming.py |  3 ++-
 model/REPST/repst.py         | 22 ++++++++++++++++++++++
 4 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml
index 54e3c38..0198c62 100755
--- a/config/REPST/PEMS-BAY.yaml
+++ b/config/REPST/PEMS-BAY.yaml
@@ -1,7 +1,7 @@
 basic:
   dataset: "PEMS-BAY"
   mode : "train"
-  device : "cuda:0"
+  device : "cuda:1"
   model: "REPST"
 
 data:
diff --git a/model/AEPSA/reprogramming.py b/model/AEPSA/reprogramming.py
index 02a4835..3b95ca6 100644
--- a/model/AEPSA/reprogramming.py
+++ b/model/AEPSA/reprogramming.py
@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
 
     def forward(self, x):
         b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
-        # 768,64,25
+        # permute to (b, n, pl, pn, m) so the reshape below keeps patch_len on its own axis
+        x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)
diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py
index 02a4835..3b95ca6 100644
--- a/model/REPST/reprogramming.py
+++ b/model/REPST/reprogramming.py
@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
 
     def forward(self, x):
         b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
-        # 768,64,25
+        # permute to (b, n, pl, pn, m) so the reshape below keeps patch_len on its own axis
+        x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)
diff --git a/model/REPST/repst.py b/model/REPST/repst.py
index 9a6af69..79502c2 100644
--- a/model/REPST/repst.py
+++ b/model/REPST/repst.py
@@ -77,4 +77,26 @@ class repst(nn.Module):
 
         return outputs
 
 
+if __name__ == '__main__':
+    # Smoke test: PEMS-BAY-sized input (batch 16, seq_len 24, 325 nodes, 1 feature).
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+
+    print(y.shape)
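
A standalone sketch (not part of the patch; sizes are made up) of why the permute matters: without it, x.reshape(b*n, pl, m*pn) is shape-compatible, since the element count matches, but it regroups the last three axes in raw (feature, patch_num, patch_len) memory order, so axis 1 of the result is not actually patch_len. Permuting to (b, n, patch_len, patch_num, feature) first keeps patch_len intact through the reshape:

import torch

# Made-up sizes: batch, node, feature, patch_num, patch_len.
b, n, m, pn, pl = 2, 3, 4, 5, 6
# Deterministic values so the checks below are exact, not probabilistic.
x = torch.arange(b * n * m * pn * pl, dtype=torch.float32).reshape(b, n, m, pn, pl)

# Old behaviour: the last three axes (m, pn, pl) are flattened and re-cut
# in raw memory order, so axis 1 is not patch_len.
wrong = x.reshape(b * n, pl, m * pn)

# Patched behaviour: move patch_len ahead of the axes being merged, then reshape
# (reshape, unlike view, copies the non-contiguous permuted tensor as needed).
right = x.permute(0, 1, 4, 3, 2).reshape(b * n, pl, pn * m)

# Element (batch 0, node 0, feature 0, patch 0, position 1) must land at [0, 1, 0].
assert right[0, 1, 0] == x[0, 0, 0, 0, 1]
# The unpermuted reshape puts a different patch/position there instead.
assert wrong[0, 1, 0] == x[0, 0, 0, 3, 2]

One detail the sketch makes visible: after the permute, the merged axis is laid out as (patch_num, feature), so pn * m reflects the true flat order even though the count equals m * pn.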