Fix the permute bug
This commit is contained in:
parent 095e8c60dc
commit 15f083c3d9
@@ -1,7 +1,7 @@
 basic:
   dataset: "PEMS-BAY"
   mode : "train"
-  device : "cuda:0"
+  device : "cuda:1"
   model: "REPST"
 
 data:
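The hunk only swaps the CUDA index, so the configured device has to exist on the host that runs training. A minimal sketch of a defensive lookup (hypothetical helper, not part of this repository) that falls back when the configured GPU is absent:

import torch

def resolve_device(name: str) -> torch.device:
    # Hypothetical helper for illustration: fall back when the configured GPU is absent.
    if name.startswith("cuda"):
        if not torch.cuda.is_available():
            return torch.device("cpu")
        idx = int(name.split(":")[1]) if ":" in name else 0
        if idx >= torch.cuda.device_count():
            return torch.device("cuda:0")
    return torch.device(name)

print(resolve_device("cuda:1"))  # cuda:1 only if a second GPU exists, otherwise cuda:0 / cpu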
@@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
     def forward(self, x):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
         # 768,64,25
         # why wasn't there a permute before the reshape?
+        x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
         return x.reshape(b, n, -1)
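The added permute is the substance of the fix: without it, the reshape to (b*n, pl, m*pn) carves the flattened (feature, patch_num, patch_len) block at the wrong boundaries, so the axis fed to tokenConv no longer corresponds to positions within a patch. A standalone sketch with illustrative sizes (not the repository's TokenEmbedding module) makes the difference visible:

import torch

# Illustrative sizes (assumed, not from the repo): batch, node, feature, patch_num, patch_len
b, n, m, pn, pl = 2, 3, 1, 4, 6
x = torch.randn(b, n, m, pn, pl)

# Buggy path: reshape straight from (b, n, m, pn, pl); the shapes match, but values are scrambled.
naive = x.reshape(b * n, pl, m * pn)
# Fixed path: bring patch_len forward first, then flatten the remaining axes.
fixed = x.permute(0, 1, 4, 3, 2).reshape(b * n, pl, m * pn)

# With the permute, each column of the last axis is one patch of one feature:
print(torch.equal(fixed[0, :, 0], x[0, 0, 0, 0, :]))  # True
# Without it, the same column mixes values taken from different patches:
print(torch.equal(naive[0, :, 0], x[0, 0, 0, 0, :]))  # False (with overwhelming probability)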
@@ -1,3 +1,4 @@
+from tkinter import Y
 import torch
 import torch.nn as nn
 from transformers.models.gpt2.modeling_gpt2 import GPT2Model
@@ -77,4 +78,25 @@ class repst(nn.Module):
         return outputs
 
+if __name__ == '__main__':
+    configs = {
+        'device': 'cuda:0',
+        'pred_len': 24,
+        'seq_len': 24,
+        'patch_len': 6,
+        'stride': 7,
+        'dropout': 0.2,
+        'gpt_layers': 9,
+        'd_ff': 128,
+        'gpt_path': './GPT-2',
+        'd_model': 64,
+        'n_heads': 1,
+        'input_dim': 1
+    }
+    model = repst(configs)
+    x = torch.randn(16, 24, 325, 1)
+    y = model(x)
+
+    print(y.shape)
+
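As a rough sanity check on the smoke-test configs, assuming the common PatchTST-style patching formula (not verified against this repository's embedding code), patch_len 6 with stride 7 over seq_len 24 yields three patches per series:

# Assumed PatchTST-style patch count; the repo's actual patching may differ.
seq_len, patch_len, stride = 24, 6, 7
patch_num = (seq_len - patch_len) // stride + 1
print(patch_num)  # 3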