Compare commits
2 Commits
2f58fc9348 ... 0ad1494aec
| Author | SHA1 | Date |
|---|---|---|
| | 0ad1494aec | |
| | 35310a4f4a | |
@@ -1,5 +1,5 @@
 basic:
-  dataset: "BJTaxi-Inflow"
+  dataset: "BJTaxi-InFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
@@ -1,5 +1,5 @@
 basic:
-  dataset: "BJTaxi-outflow"
+  dataset: "BJTaxi-OutFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
@@ -1,5 +1,5 @@
 basic:
-  dataset: "NYCBike-inflow"
+  dataset: "NYCBike-InFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
@@ -1,5 +1,5 @@
 basic:
-  dataset: "NYCBike-outflow"
+  dataset: "NYCBike-OutFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
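The four config hunks above change only the capitalization of the `dataset` value. That matters if the training script resolves the name with an exact, case-sensitive lookup; the sketch below is a minimal illustration of that failure mode using a hypothetical registry and hypothetical paths, not the repository's actual loader.

```python
# Hypothetical registry: keys mirror the renamed dataset strings,
# paths are made up purely for illustration.
DATASET_ROOTS = {
    "BJTaxi-InFlow": "data/BJTaxi/inflow",
    "BJTaxi-OutFlow": "data/BJTaxi/outflow",
    "NYCBike-InFlow": "data/NYCBike/inflow",
    "NYCBike-OutFlow": "data/NYCBike/outflow",
}

def resolve_dataset(name: str) -> str:
    # Dict lookups are case-sensitive: the old value "BJTaxi-Inflow" would
    # raise KeyError, while the corrected "BJTaxi-InFlow" resolves cleanly.
    try:
        return DATASET_ROOTS[name]
    except KeyError as exc:
        raise KeyError(f"Unknown dataset {name!r}; known: {sorted(DATASET_ROOTS)}") from exc

print(resolve_dataset("BJTaxi-InFlow"))  # data/BJTaxi/inflow
```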
@@ -3,6 +3,7 @@ basic:
   mode: "train"
   device: "cuda:0"
   model: "STID"
+  seed: 2023
 
 data:
   num_nodes: 325
@@ -3,6 +3,7 @@ basic:
   mode: "train"
   device: "cuda:0"
   model: "STID"
+  seed: 2023
 
 data:
   num_nodes: 307
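Both STID config hunks add a `seed` key. The diff does not show how the key is consumed; a common pattern is to seed every RNG once at startup, e.g. with a helper like the hypothetical `set_seed` below.

```python
import random

import numpy as np
import torch

def set_seed(seed: int) -> None:
    """Seed the Python, NumPy and PyTorch RNGs so runs are repeatable."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)           # CPU generator
    torch.cuda.manual_seed_all(seed)  # all CUDA devices; a no-op without a GPU

# e.g. after loading the YAML config:
# set_seed(config['basic'].get('seed', 2023))
```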
@@ -162,6 +162,7 @@ class AEPSA(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
         self.num_nodes = configs.get('num_nodes', 325)  # node count, added as a config option
+        self.output_dim = configs.get('output_dim', 1)
 
         self.word_choice = GumbelSoftmax(configs['word_num'])
 
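`configs.get('output_dim', 1)` falls back to 1 when the key is missing, so config files written before this change keep the previous single-channel behaviour. A tiny illustration with a made-up config dict:

```python
configs = {"d_ff": 256, "gpt_path": "gpt2"}                # an older config without output_dim
print(configs.get("output_dim", 1))                        # -> 1, the pre-change default
print({**configs, "output_dim": 2}.get("output_dim", 1))   # -> 2 once the key is set
```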
@@ -174,7 +175,7 @@ class AEPSA(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums
 
         # word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
 
         # GPT-2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
@@ -225,7 +226,7 @@ class AEPSA(nn.Module):
         x: input data of shape [B, T, N, C]
         the graph structure is generated automatically; no external adjacency matrix is required
         """
-        x = x[..., :1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')
         # original patching
         enc_out, n_vars = self.patch_embedding(x_enc)  # (B, N, C)
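The forward-pass change replaces the hard-coded `x[..., :1]`, which always kept only the first channel of the `[B, T, N, C]` input, with `x[..., :self.input_dim]`, which keeps all configured channels. A quick shape check with illustrative sizes:

```python
import torch

B, T, N, C = 2, 12, 325, 3          # illustrative sizes
x = torch.randn(B, T, N, C)

input_dim = 2
print(x[..., :1].shape)             # torch.Size([2, 12, 325, 1]) - old slice, extra channels dropped
print(x[..., :input_dim].shape)     # torch.Size([2, 12, 325, 2]) - new slice, keeps input_dim channels
```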
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
 
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -30,7 +30,6 @@ class TokenEmbedding(nn.Module):
 
     def forward(self, x):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
-        # why is there no reshape after the permute?
         x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
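Inside `TokenEmbedding`, the final linear layer now maps the flattened `patch_num * input_dim` axis to `output_dim` instead of a hard-coded 1. The sketch below traces only that last step with made-up sizes; the conv and permute in `forward` are unchanged by this commit.

```python
import torch
import torch.nn as nn

b, n, m, pn = 2, 4, 2, 6             # batch, node, feature, patch_num (illustrative)
d_model, output_dim = 8, 2

# After tokenConv the tensor has shape (batch*node, d_model, feature*patch_num);
# the confusion layer acts on that last axis.
conv_out = torch.randn(b * n, d_model, m * pn)

old_head = nn.Linear(pn * m, 1)            # before: always a single output channel
new_head = nn.Linear(pn * m, output_dim)   # after: output_dim channels

print(old_head(conv_out).shape)   # torch.Size([8, 8, 1])
print(new_head(conv_out).shape)   # torch.Size([8, 8, 2])
```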
@@ -38,22 +37,20 @@ class TokenEmbedding(nn.Module):
 
 
 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-
         n_vars = x.shape[2]
         x = self.padding_patch_layer(x)
         x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
         x_value_embed = self.value_embedding(x)
-
         return self.dropout(x_value_embed), n_vars
 
 
 class ReprogrammingLayer(nn.Module):
@@ -85,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
 
     def reprogramming(self, target_embedding, source_embedding, value_embedding):
         B, L, H, E = target_embedding.shape
-
         scale = 1. / sqrt(E)
-
         scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
-
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
-
         return reprogramming_embedding
 
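The last hunk only removes blank lines, but the two einsum strings pin down the shapes in the reprogramming attention: `target_embedding` is `(B, L, H, E)`, `source_embedding` and `value_embedding` are `(S, H, E)`, the scores are `(B, H, L, S)`, and the result comes back as `(B, L, H, E)`. A standalone shape check with random tensors and illustrative sizes:

```python
from math import sqrt

import torch

B, L, S, H, E = 2, 5, 7, 4, 16                            # illustrative sizes
target = torch.randn(B, L, H, E)
source = torch.randn(S, H, E)
value = torch.randn(S, H, E)

scale = 1. / sqrt(E)
scores = torch.einsum("blhe,she->bhls", target, source)   # (B, H, L, S)
A = torch.softmax(scale * scores, dim=-1)                  # attention over the S source tokens
out = torch.einsum("bhls,she->blhe", A, value)             # (B, L, H, E)

print(scores.shape, out.shape)  # torch.Size([2, 4, 5, 7]) torch.Size([2, 5, 4, 16])
```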