修复permute的bug

This commit is contained in:
czzhangheng 2025-11-12 16:40:13 +08:00
parent 095e8c60dc
commit 15f083c3d9
4 changed files with 27 additions and 3 deletions

View File

@ -1,7 +1,7 @@
basic:
dataset: "PEMS-BAY"
mode : "train"
device : "cuda:0"
device : "cuda:1"
model: "REPST"
data:

View File

@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
def forward(self, x):
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
# 768,64,25
# 为什么没permute后reshape?
x = x.permute(0, 1, 4, 3, 2)
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
x = self.confusion_layer(x)
return x.reshape(b, n, -1)

View File

@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
def forward(self, x):
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
# 768,64,25
# 为什么没permute后reshape?
x = x.permute(0, 1, 4, 3, 2)
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
x = self.confusion_layer(x)
return x.reshape(b, n, -1)

View File

@ -1,3 +1,4 @@
from tkinter import Y
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@ -77,4 +78,25 @@ class repst(nn.Module):
return outputs
if __name__ == '__main__':
configs = {
'device': 'cuda:0',
'pred_len': 24,
'seq_len': 24,
'patch_len': 6,
'stride': 7,
'dropout': 0.2,
'gpt_layers': 9,
'd_ff': 128,
'gpt_path': './GPT-2',
'd_model': 64,
'n_heads': 1,
'input_dim': 1
}
model = repst(configs)
x = torch.randn(16, 24, 325, 1)
y = model(x)
print(y.shape)