修复permute的bug
This commit is contained in:
parent
095e8c60dc
commit
15f083c3d9
|
|
@ -1,7 +1,7 @@
|
||||||
basic:
|
basic:
|
||||||
dataset: "PEMS-BAY"
|
dataset: "PEMS-BAY"
|
||||||
mode : "train"
|
mode : "train"
|
||||||
device : "cuda:0"
|
device : "cuda:1"
|
||||||
model: "REPST"
|
model: "REPST"
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
|
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
|
||||||
# 768,64,25
|
# 为什么没permute后reshape?
|
||||||
|
x = x.permute(0, 1, 4, 3, 2)
|
||||||
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
|
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
|
||||||
x = self.confusion_layer(x)
|
x = self.confusion_layer(x)
|
||||||
return x.reshape(b, n, -1)
|
return x.reshape(b, n, -1)
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,8 @@ class TokenEmbedding(nn.Module):
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
|
b, n, m, pn, pl = x.shape # batch, node, feature, patch_num, patch_len
|
||||||
# 768,64,25
|
# 为什么没permute后reshape?
|
||||||
|
x = x.permute(0, 1, 4, 3, 2)
|
||||||
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
|
x = self.tokenConv(x.reshape(b*n, pl, m*pn)) # batch*node, patch_len, feature*patch_num
|
||||||
x = self.confusion_layer(x)
|
x = self.confusion_layer(x)
|
||||||
return x.reshape(b, n, -1)
|
return x.reshape(b, n, -1)
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
from tkinter import Y
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
|
|
@ -77,4 +78,25 @@ class repst(nn.Module):
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
configs = {
|
||||||
|
'device': 'cuda:0',
|
||||||
|
'pred_len': 24,
|
||||||
|
'seq_len': 24,
|
||||||
|
'patch_len': 6,
|
||||||
|
'stride': 7,
|
||||||
|
'dropout': 0.2,
|
||||||
|
'gpt_layers': 9,
|
||||||
|
'd_ff': 128,
|
||||||
|
'gpt_path': './GPT-2',
|
||||||
|
'd_model': 64,
|
||||||
|
'n_heads': 1,
|
||||||
|
'input_dim': 1
|
||||||
|
}
|
||||||
|
model = repst(configs)
|
||||||
|
x = torch.randn(16, 24, 325, 1)
|
||||||
|
y = model(x)
|
||||||
|
|
||||||
|
print(y.shape)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue