From 35310a4f4a1a0c4122829ebe8e2feddcd4622b5f Mon Sep 17 00:00:00 2001
From: czzhangheng
Date: Thu, 27 Nov 2025 19:20:59 +0800
Subject: [PATCH] Fix model and configuration errors
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 config/AEPSA/BJTaxi-Inflow.yaml  | 2 +-
 config/AEPSA/BJTaxi-outflow.yaml | 2 +-
 config/STID/PEMS-BAY.yaml        | 1 +
 config/STID/PEMSD4.yaml          | 1 +
 model/AEPSA/aepsa.py             | 5 +++--
 model/AEPSA/reprogramming.py     | 15 ++++-----------
 6 files changed, 11 insertions(+), 15 deletions(-)

diff --git a/config/AEPSA/BJTaxi-Inflow.yaml b/config/AEPSA/BJTaxi-Inflow.yaml
index 2384cf5..f412db7 100644
--- a/config/AEPSA/BJTaxi-Inflow.yaml
+++ b/config/AEPSA/BJTaxi-Inflow.yaml
@@ -1,5 +1,5 @@
 basic:
-  dataset: "BJTaxi-Inflow"
+  dataset: "BJTaxi-InFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
diff --git a/config/AEPSA/BJTaxi-outflow.yaml b/config/AEPSA/BJTaxi-outflow.yaml
index 2ce962b..acc02c4 100644
--- a/config/AEPSA/BJTaxi-outflow.yaml
+++ b/config/AEPSA/BJTaxi-outflow.yaml
@@ -1,5 +1,5 @@
 basic:
-  dataset: "BJTaxi-outflow"
+  dataset: "BJTaxi-OutFlow"
   mode : "train"
   device : "cuda:0"
   model: "AEPSA"
diff --git a/config/STID/PEMS-BAY.yaml b/config/STID/PEMS-BAY.yaml
index 0a10a68..9178bf9 100755
--- a/config/STID/PEMS-BAY.yaml
+++ b/config/STID/PEMS-BAY.yaml
@@ -3,6 +3,7 @@ basic:
   mode: "train"
   device: "cuda:0"
   model: "STID"
+  seed: 2023
 
 data:
   num_nodes: 325
diff --git a/config/STID/PEMSD4.yaml b/config/STID/PEMSD4.yaml
index dfb0726..34b0cbc 100755
--- a/config/STID/PEMSD4.yaml
+++ b/config/STID/PEMSD4.yaml
@@ -3,6 +3,7 @@ basic:
   mode: "train"
   device: "cuda:0"
   model: "STID"
+  seed: 2023
 
 data:
   num_nodes: 307
diff --git a/model/AEPSA/aepsa.py b/model/AEPSA/aepsa.py
index 0c13da4..7ea003d 100644
--- a/model/AEPSA/aepsa.py
+++ b/model/AEPSA/aepsa.py
@@ -162,6 +162,7 @@ class AEPSA(nn.Module):
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
         self.num_nodes = configs.get('num_nodes', 325)  # node count configuration
+        self.output_dim = configs.get('output_dim', 1)
 
         self.word_choice = GumbelSoftmax(configs['word_num'])
 
@@ -174,7 +175,7 @@ class AEPSA(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums
 
         # word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
 
         # GPT2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
@@ -225,7 +226,7 @@ class AEPSA(nn.Module):
         x: input data [B, T, N, C]
         The graph structure is generated automatically; no external adjacency matrix is required
         """
-        x = x[..., :1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')  # original patching
         enc_out, n_vars = self.patch_embedding(x_enc)  # (B, N, C)
 
diff --git a/model/AEPSA/reprogramming.py b/model/AEPSA/reprogramming.py
index 3b95ca6..1ba7976 100644
--- a/model/AEPSA/reprogramming.py
+++ b/model/AEPSA/reprogramming.py
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding,
                                    padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -30,7 +30,6 @@ class TokenEmbedding(nn.Module):
 
     def forward(self, x):
         b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
-        # why is there no reshape after the permute?
         x = x.permute(0, 1, 4, 3, 2)
         x = self.tokenConv(x.reshape(b*n, pl, m*pn))  # batch*node, patch_len, feature*patch_num
         x = self.confusion_layer(x)
@@ -38,22 +37,20 @@ class TokenEmbedding(nn.Module):
 
 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-
         n_vars = x.shape[2]
         x = self.padding_patch_layer(x)
         x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
         x_value_embed = self.value_embedding(x)
-
         return self.dropout(x_value_embed), n_vars
 
 
 class ReprogrammingLayer(nn.Module):
@@ -85,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
 
 
     def reprogramming(self, target_embedding, source_embedding, value_embedding):
         B, L, H, E = target_embedding.shape
-
         scale = 1. / sqrt(E)
-
         scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
-
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
-
         return reprogramming_embedding
\ No newline at end of file
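
A quick smoke test for the output_dim plumbing above; a minimal sketch, not
part of the patch. It assumes model/AEPSA/reprogramming.py is importable as a
package, that TokenEmbedding.forward returns the confusion_layer output
unchanged (its return statement lies outside the hunks shown), and that all
sizes below are illustrative rather than taken from the shipped configs.

    import torch
    from model.AEPSA.reprogramming import PatchEmbedding

    B, N, T = 2, 325, 12           # batch, nodes, input window (illustrative)
    input_dim, output_dim = 3, 1   # channels in / channels out (illustrative)
    d_model, patch_len, stride, dropout = 64, 4, 4, 0.1
    # ReplicationPad1d((0, stride)) pads T to T + stride before unfold
    patch_num = (T + stride - patch_len) // stride + 1

    embed = PatchEmbedding(d_model, patch_len, stride, dropout,
                           patch_num, input_dim, output_dim)
    x = torch.randn(B, N, input_dim, T)  # aepsa.py feeds (b, n, c, t) after rearrange
    out, n_vars = embed(x)
    print(out.shape)  # expected: (B * N, d_model, output_dim)
    print(n_vars)     # equals input_dim; the aepsa.py call site unpacks two values

The companion change x = x[..., :self.input_dim] in aepsa.py is what makes
input_dim > 1 meaningful here: the old hard-coded x[..., :1] silently dropped
every input channel past the first.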
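The seed: 2023 key added to the STID configs only takes effect if the training
entry point consumes it. A minimal sketch of the usual pattern, assuming a
parsed config dict named config; the repo's runner may already do this under
another name.

    import random

    import numpy as np
    import torch

    def set_seed(seed: int) -> None:
        # seed Python, NumPy, and torch (CPU and all visible GPUs)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    set_seed(config['basic']['seed'])  # 2023 in the configs above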