From 1a13a32688503dabc99bb39cffc0da695bd2775d Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Tue, 16 Dec 2025 21:47:17 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=85=8D=E7=BD=AE=E3=80=82AS?= =?UTF-8?q?TRA=20v2=20or=20v3=E4=BD=BF=E7=94=A8=E7=A1=AC=E5=8F=82=E6=95=B0?= =?UTF-8?q?=EF=BC=8C=E7=A1=AE=E4=BF=9D=E9=85=8D=E7=BD=AE=E6=96=87=E4=BB=B6?= =?UTF-8?q?=E5=AE=8C=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/ASTRA/BJTaxi-InFlow.yaml | 1 + config/ASTRA/BJTaxi-OutFlow.yaml | 3 ++- config/ASTRA/METR-LA.yaml | 1 + config/ASTRA/NYCBike-InFlow.yaml | 3 ++- config/ASTRA/NYCBike-OutFlow.yaml | 1 + config/ASTRA/PEMS-BAY.yaml | 1 + config/ASTRA/SolarEnergy.yaml | 1 + config/ASTRA_v2/AirQuality.yaml | 3 +++ config/ASTRA_v2/BJTaxi-InFlow.yaml | 4 ++++ config/ASTRA_v2/BJTaxi-OutFlow.yaml | 4 ++++ config/ASTRA_v2/METR-LA.yaml | 4 ++++ config/ASTRA_v2/NYCBike-InFlow.yaml | 4 ++++ config/ASTRA_v2/NYCBike-OutFlow.yaml | 4 ++++ config/ASTRA_v2/PEMS-BAY.yaml | 4 ++++ config/ASTRA_v2/SolarEnergy.yaml | 4 ++++ config/ASTRA_v3/AirQuality.yaml | 4 ++++ config/ASTRA_v3/BJTaxi-InFlow.yaml | 4 ++++ config/ASTRA_v3/BJTaxi-OutFlow.yaml | 4 ++++ config/ASTRA_v3/METR-LA.yaml | 4 ++++ config/ASTRA_v3/NYCBike-InFlow.yaml | 4 ++++ config/ASTRA_v3/NYCBike-OutFlow.yaml | 4 ++++ config/ASTRA_v3/PEMS-BAY.yaml | 4 ++++ config/ASTRA_v3/SolarEnergy.yaml | 8 ++++++-- model/ASTRA/astrav2.py | 15 +++++++++------ model/ASTRA/astrav3.py | 17 ++++++++++------- model/REPST/repst.py | 2 +- train.py | 2 +- 27 files changed, 95 insertions(+), 19 deletions(-) diff --git a/config/ASTRA/BJTaxi-InFlow.yaml b/config/ASTRA/BJTaxi-InFlow.yaml index c2766bb..8569919 100644 --- a/config/ASTRA/BJTaxi-InFlow.yaml +++ b/config/ASTRA/BJTaxi-InFlow.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA/BJTaxi-OutFlow.yaml b/config/ASTRA/BJTaxi-OutFlow.yaml index ee570f3..d8f0e5d 100644 --- a/config/ASTRA/BJTaxi-OutFlow.yaml +++ b/config/ASTRA/BJTaxi-OutFlow.yaml @@ -17,7 +17,8 @@ data: steps_per_day: 48 test_ratio: 0.2 val_ratio: 0.2 - + output_dim: 1 + model: d_ff: 128 d_model: 64 diff --git a/config/ASTRA/METR-LA.yaml b/config/ASTRA/METR-LA.yaml index 87bf1ac..3ae73ec 100644 --- a/config/ASTRA/METR-LA.yaml +++ b/config/ASTRA/METR-LA.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA/NYCBike-InFlow.yaml b/config/ASTRA/NYCBike-InFlow.yaml index 1c80773..0099f8f 100644 --- a/config/ASTRA/NYCBike-InFlow.yaml +++ b/config/ASTRA/NYCBike-InFlow.yaml @@ -32,7 +32,8 @@ model: seq_len: 24 stride: 7 word_num: 1000 - + output_dim: 1 + train: batch_size: 32 debug: false diff --git a/config/ASTRA/NYCBike-OutFlow.yaml b/config/ASTRA/NYCBike-OutFlow.yaml index 1ece121..f46cece 100644 --- a/config/ASTRA/NYCBike-OutFlow.yaml +++ b/config/ASTRA/NYCBike-OutFlow.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA/PEMS-BAY.yaml b/config/ASTRA/PEMS-BAY.yaml index e111654..2b2384d 100755 --- a/config/ASTRA/PEMS-BAY.yaml +++ b/config/ASTRA/PEMS-BAY.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA/SolarEnergy.yaml b/config/ASTRA/SolarEnergy.yaml index 4160077..dd64d64 100644 --- a/config/ASTRA/SolarEnergy.yaml +++ b/config/ASTRA/SolarEnergy.yaml @@ -32,6 +32,7 @@ model: seq_len: 24 stride: 7 word_num: 1000 + output_dim: 1 train: batch_size: 64 diff --git a/config/ASTRA_v2/AirQuality.yaml b/config/ASTRA_v2/AirQuality.yaml index ed22962..9073676 100644 --- a/config/ASTRA_v2/AirQuality.yaml +++ b/config/ASTRA_v2/AirQuality.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -33,6 +34,8 @@ model: stride: 7 word_num: 1000 output_dim: 6 + graph_dim: 64 + graph_embed_dim: 10 train: batch_size: 16 diff --git a/config/ASTRA_v2/BJTaxi-InFlow.yaml b/config/ASTRA_v2/BJTaxi-InFlow.yaml index d1cc5ea..5968cca 100644 --- a/config/ASTRA_v2/BJTaxi-InFlow.yaml +++ b/config/ASTRA_v2/BJTaxi-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/BJTaxi-OutFlow.yaml b/config/ASTRA_v2/BJTaxi-OutFlow.yaml index d6e0723..03859eb 100644 --- a/config/ASTRA_v2/BJTaxi-OutFlow.yaml +++ b/config/ASTRA_v2/BJTaxi-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/METR-LA.yaml b/config/ASTRA_v2/METR-LA.yaml index dca4bb4..db6e3a8 100644 --- a/config/ASTRA_v2/METR-LA.yaml +++ b/config/ASTRA_v2/METR-LA.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v2/NYCBike-InFlow.yaml b/config/ASTRA_v2/NYCBike-InFlow.yaml index de5b6a1..caeccb7 100644 --- a/config/ASTRA_v2/NYCBike-InFlow.yaml +++ b/config/ASTRA_v2/NYCBike-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/NYCBike-OutFlow.yaml b/config/ASTRA_v2/NYCBike-OutFlow.yaml index dda718d..a586f9a 100644 --- a/config/ASTRA_v2/NYCBike-OutFlow.yaml +++ b/config/ASTRA_v2/NYCBike-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v2/PEMS-BAY.yaml b/config/ASTRA_v2/PEMS-BAY.yaml index 2f6dfbf..2705006 100755 --- a/config/ASTRA_v2/PEMS-BAY.yaml +++ b/config/ASTRA_v2/PEMS-BAY.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v2/SolarEnergy.yaml b/config/ASTRA_v2/SolarEnergy.yaml index 9b6a223..f6405a5 100644 --- a/config/ASTRA_v2/SolarEnergy.yaml +++ b/config/ASTRA_v2/SolarEnergy.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/AirQuality.yaml b/config/ASTRA_v3/AirQuality.yaml index d4cb947..c4481c0 100644 --- a/config/ASTRA_v3/AirQuality.yaml +++ b/config/ASTRA_v3/AirQuality.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -33,6 +34,9 @@ model: stride: 7 word_num: 1000 output_dim: 6 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 6 train: batch_size: 16 diff --git a/config/ASTRA_v3/BJTaxi-InFlow.yaml b/config/ASTRA_v3/BJTaxi-InFlow.yaml index 34abfd8..bb09013 100644 --- a/config/ASTRA_v3/BJTaxi-InFlow.yaml +++ b/config/ASTRA_v3/BJTaxi-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/BJTaxi-OutFlow.yaml b/config/ASTRA_v3/BJTaxi-OutFlow.yaml index 8e6b30d..0b4e8df 100644 --- a/config/ASTRA_v3/BJTaxi-OutFlow.yaml +++ b/config/ASTRA_v3/BJTaxi-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/METR-LA.yaml b/config/ASTRA_v3/METR-LA.yaml index 2b5512b..5efa494 100644 --- a/config/ASTRA_v3/METR-LA.yaml +++ b/config/ASTRA_v3/METR-LA.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/NYCBike-InFlow.yaml b/config/ASTRA_v3/NYCBike-InFlow.yaml index 18c4fa3..52008cc 100644 --- a/config/ASTRA_v3/NYCBike-InFlow.yaml +++ b/config/ASTRA_v3/NYCBike-InFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/NYCBike-OutFlow.yaml b/config/ASTRA_v3/NYCBike-OutFlow.yaml index ff73662..0977912 100644 --- a/config/ASTRA_v3/NYCBike-OutFlow.yaml +++ b/config/ASTRA_v3/NYCBike-OutFlow.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 32 diff --git a/config/ASTRA_v3/PEMS-BAY.yaml b/config/ASTRA_v3/PEMS-BAY.yaml index 6739aeb..9ff0fd0 100755 --- a/config/ASTRA_v3/PEMS-BAY.yaml +++ b/config/ASTRA_v3/PEMS-BAY.yaml @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,6 +33,9 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: batch_size: 16 diff --git a/config/ASTRA_v3/SolarEnergy.yaml b/config/ASTRA_v3/SolarEnergy.yaml index 289b839..c3f8863 100644 --- a/config/ASTRA_v3/SolarEnergy.yaml +++ b/config/ASTRA_v3/SolarEnergy.yaml @@ -6,7 +6,7 @@ basic: seed: 2023 data: - batch_size: 64 + batch_size: 16 column_wise: false days_per_week: 7 horizon: 24 @@ -19,6 +19,7 @@ data: val_ratio: 0.2 model: + cheb: 3 d_ff: 128 d_model: 64 dropout: 0.2 @@ -32,9 +33,12 @@ model: seq_len: 24 stride: 7 word_num: 1000 + graph_dim: 64 + graph_embed_dim: 10 + output_dim: 1 train: - batch_size: 64 + batch_size: 16 debug: false early_stop: true early_stop_patience: 15 diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index 22e25b9..f18ac90 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -127,8 +127,11 @@ class ASTRA(nn.Module): self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 - self.num_nodes = configs.get('num_nodes', 325) # 节点数量 - self.output_dim = configs.get('output_dim', 1) + self.num_nodes = configs['num_nodes'] # 节点数量 + self.output_dim = configs['output_dim'] + self.cheb = configs['cheb'] + self.graph_dim = configs['graph_dim'] + self.graph_embed_dim = configs['graph_embed_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -152,18 +155,18 @@ class ASTRA(nn.Module): # 初始化图增强编码器 self.graph_encoder = GraphEnhancedEncoder( - K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 + K=self.cheb, # Chebyshev多项式阶数 in_dim=self.d_model, # 输入特征维度 - hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + hidden_dim=self.graph_dim, # 隐藏层维度 num_nodes=self.num_nodes, # 节点数量 - embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + embed_dim=self.graph_embed_dim, # 节点嵌入维度 device=self.device, # 运行设备 temporal_dim=self.seq_len, # 时间序列长度 num_features=self.input_dim # 特征通道数 ) self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 - configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 + self.graph_dim * (self.cheb + 1), # 输入维度 self.d_model # 输出维度 ) diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 59fc11d..7f4317c 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -127,8 +127,11 @@ class ASTRA(nn.Module): self.gpt_layers = configs['gpt_layers'] # 使用的GPT2层数 self.d_ff = configs['d_ff'] # 前馈网络隐藏层维度 self.gpt_path = configs['gpt_path'] # 预训练GPT2模型路径 - self.num_nodes = configs.get('num_nodes', 325) # 节点数量 - self.output_dim = configs.get('output_dim', 1) + self.num_nodes = configs['num_nodes'] # 节点数量 + self.output_dim = configs['output_dim'] + self.cheb = configs['cheb'] + self.graph_dim = configs['graph_dim'] + self.graph_embed_dim = configs['graph_embed_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) # 词汇选择层 @@ -148,23 +151,23 @@ class ASTRA(nn.Module): self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) # 词嵌入权重 self.vocab_size = self.word_embeddings.shape[0] # 词汇表大小 self.mapping_layer = nn.Linear(self.vocab_size, 1) # 映射层 - self.reprogramming_layer = ReprogrammingLayer(self.d_model + configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 + self.reprogramming_layer = ReprogrammingLayer(self.d_model + self.graph_dim * (self.cheb + 1), self.n_heads, self.d_keys, self.d_llm) # 重编程层 # 初始化图增强编码器 self.graph_encoder = GraphEnhancedEncoder( K=configs.get('chebyshev_order', 3), # Chebyshev多项式阶数 in_dim=self.d_model, # 输入特征维度 - hidden_dim=configs.get('graph_hidden_dim', 32), # 隐藏层维度 + hidden_dim=self.graph_dim, # 隐藏层维度 num_nodes=self.num_nodes, # 节点数量 - embed_dim=configs.get('graph_embed_dim', 10), # 节点嵌入维度 + embed_dim=self.graph_embed_dim, # 节点嵌入维度 device=self.device, # 运行设备 temporal_dim=self.seq_len, # 时间序列长度 num_features=self.input_dim # 特征通道数 ) self.graph_projection = nn.Linear( # 图特征投影层,每一k阶的切比雪夫权重映射到隐藏维度 - configs.get('graph_hidden_dim', 32) * (configs.get('chebyshev_order', 3) + 1), # 输入维度 - self.d_model # 输出维度 + self.graph_dim * (self.cheb + 1), # 输入维度 + self.d_model # 输出维度 ) self.out_mlp = nn.Sequential( diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 5b709a4..9afbda1 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -19,7 +19,7 @@ class repst(nn.Module): self.gpt_layers = configs['gpt_layers'] self.d_ff = configs['d_ff'] self.gpt_path = configs['gpt_path'] - self.output_dim = configs.get('output_dim', 1) + self.output_dim = configs['output_dim'] self.word_choice = GumbelSoftmax(configs['word_num']) diff --git a/train.py b/train.py index da6c058..e9db08b 100644 --- a/train.py +++ b/train.py @@ -90,7 +90,7 @@ def main(model, data, debug=False): if __name__ == "__main__": # 调试用 # model_list = ["iTransformer", "PatchTST", "HI"] - model_list = ["ASTRA_v3", "ASTRA_v2", "ASTRA", "REPST", "STAEFormer", "MTGNN", "iTransformer", "PatchTST", "HI"] + model_list = ["ASTRA_v3"] # model_list = ["MTGNN"] # dataset_list = ["AirQuality", "SolarEnergy", "PEMS-BAY", "METR-LA", "BJTaxi-InFlow", "BJTaxi-OutFlow", "NYCBike-InFlow", "NYCBike-OutFlow"] # dataset_list = ["AirQuality"]