diff --git a/.vscode/launch.json b/.vscode/launch.json index 3dc2b03..4fb2121 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -153,7 +153,7 @@ "args": "--config ./config/REPST/AirQuality.yaml" }, - // AEPSA 模型组 + // ASTRA 模型组 { "name": "AEPSA: PEMS-BAY", "type": "debugpy", diff --git a/config/AEPSA/AirQuality.yaml b/config/AEPSA/AirQuality.yaml index d6061d9..455fc4b 100644 --- a/config/AEPSA/AirQuality.yaml +++ b/config/AEPSA/AirQuality.yaml @@ -2,7 +2,7 @@ basic: dataset: AirQuality device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-InFlow.yaml b/config/AEPSA/BJTaxi-InFlow.yaml index a453b38..c2766bb 100644 --- a/config/AEPSA/BJTaxi-InFlow.yaml +++ b/config/AEPSA/BJTaxi-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-Inflow.yaml b/config/AEPSA/BJTaxi-Inflow.yaml index a453b38..c2766bb 100644 --- a/config/AEPSA/BJTaxi-Inflow.yaml +++ b/config/AEPSA/BJTaxi-Inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-OutFlow.yaml b/config/AEPSA/BJTaxi-OutFlow.yaml index 9fa0f5f..ee570f3 100644 --- a/config/AEPSA/BJTaxi-OutFlow.yaml +++ b/config/AEPSA/BJTaxi-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/BJTaxi-outflow.yaml b/config/AEPSA/BJTaxi-outflow.yaml index 9fa0f5f..ee570f3 100644 --- a/config/AEPSA/BJTaxi-outflow.yaml +++ b/config/AEPSA/BJTaxi-outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: BJTaxi-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/METR-LA.yaml b/config/AEPSA/METR-LA.yaml index a623226..87bf1ac 100644 --- a/config/AEPSA/METR-LA.yaml +++ b/config/AEPSA/METR-LA.yaml @@ -2,7 +2,7 @@ basic: dataset: METR-LA device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-InFlow.yaml b/config/AEPSA/NYCBike-InFlow.yaml index b561493..1c80773 100644 --- a/config/AEPSA/NYCBike-InFlow.yaml +++ b/config/AEPSA/NYCBike-InFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-OutFlow.yaml b/config/AEPSA/NYCBike-OutFlow.yaml index 5c4da71..1ece121 100644 --- a/config/AEPSA/NYCBike-OutFlow.yaml +++ b/config/AEPSA/NYCBike-OutFlow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-inflow.yaml b/config/AEPSA/NYCBike-inflow.yaml index e4ba138..5431fba 100644 --- a/config/AEPSA/NYCBike-inflow.yaml +++ b/config/AEPSA/NYCBike-inflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-InFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/NYCBike-outflow.yaml b/config/AEPSA/NYCBike-outflow.yaml index 7cb6798..194c330 100644 --- a/config/AEPSA/NYCBike-outflow.yaml +++ b/config/AEPSA/NYCBike-outflow.yaml @@ -2,7 +2,7 @@ basic: dataset: NYCBike-OutFlow device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/PEMS-BAY.yaml b/config/AEPSA/PEMS-BAY.yaml index f75c63a..e111654 100755 --- a/config/AEPSA/PEMS-BAY.yaml +++ b/config/AEPSA/PEMS-BAY.yaml @@ -2,7 +2,7 @@ basic: dataset: PEMS-BAY device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/config/AEPSA/SolarEnergy.yaml b/config/AEPSA/SolarEnergy.yaml index 669c9f4..4160077 100644 --- a/config/AEPSA/SolarEnergy.yaml +++ b/config/AEPSA/SolarEnergy.yaml @@ -2,7 +2,7 @@ basic: dataset: SolarEnergy device: cuda:0 mode: train - model: AEPSA + model: ASTRA seed: 2023 data: diff --git a/model/ASTRA/astra.py b/model/ASTRA/astra.py index 7ea003d..0ed2333 100644 --- a/model/ASTRA/astra.py +++ b/model/ASTRA/astra.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer import torch.nn.functional as F class DynamicGraphEnhancer(nn.Module): @@ -147,10 +147,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) -class AEPSA(nn.Module): +class ASTRA(nn.Module): def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] self.pred_len = configs['pred_len'] self.seq_len = configs['seq_len'] diff --git a/model/ASTRA/astrav2.py b/model/ASTRA/astrav2.py index aac9149..79a1330 100644 --- a/model/ASTRA/astrav2.py +++ b/model/ASTRA/astrav2.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import ReprogrammingLayer import torch.nn.functional as F # 基于动态图增强的时空序列预测模型实现 @@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 -class AEPSA(nn.Module): +class ASTRA(nn.Module): """自适应特征投影时空自注意力模型""" def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] # 运行设备 self.pred_len = configs['pred_len'] # 预测序列长度 self.seq_len = configs['seq_len'] # 输入序列长度 diff --git a/model/ASTRA/astrav3.py b/model/ASTRA/astrav3.py index 99c6748..a29bfc3 100644 --- a/model/ASTRA/astrav3.py +++ b/model/ASTRA/astrav3.py @@ -2,8 +2,8 @@ import torch import torch.nn as nn from transformers.models.gpt2.modeling_gpt2 import GPT2Model from einops import rearrange -from model.AEPSA.normalizer import GumbelSoftmax -from model.AEPSA.reprogramming import ReprogrammingLayer +from model.ASTRA.normalizer import GumbelSoftmax +from model.ASTRA.reprogramming import ReprogrammingLayer import torch.nn.functional as F # 基于动态图增强的时空序列预测模型实现 @@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module): return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征 -class AEPSA(nn.Module): +class ASTRA(nn.Module): """自适应特征投影时空自注意力模型""" def __init__(self, configs): - super(AEPSA, self).__init__() + super(ASTRA, self).__init__() self.device = configs['device'] # 运行设备 self.pred_len = configs['pred_len'] # 预测序列长度 self.seq_len = configs['seq_len'] # 输入序列长度 diff --git a/model/model_selector.py b/model/model_selector.py index 633b02c..da54b33 100755 --- a/model/model_selector.py +++ b/model/model_selector.py @@ -23,9 +23,9 @@ from model.ST_SSL.ST_SSL import STSSLModel from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STAWnet.STAWnet import STAWnet from model.REPST.repst import repst as REPST -from model.AEPSA.aepsa import AEPSA as AEPSA -from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 -from model.AEPSA.aepsav3 import AEPSA as AEPSAv3 +from model.ASTRA.astra import ASTRA as ASTRA +from model.ASTRA.astrav2 import ASTRA as ASTRAv2 +from model.ASTRA.astrav3 import ASTRA as ASTRAv3 @@ -83,9 +83,9 @@ def model_selector(config): return STAWnet(model_config) case "REPST": return REPST(model_config) - case "AEPSA": - return AEPSA(model_config) - case "AEPSA_v2": - return AEPSAv2(model_config) - case "AEPSA_v3": - return AEPSAv3(model_config) + case "ASTRA": + return ASTRA(model_config) + case "ASTRA_v2": + return ASTRAv2(model_config) + case "ASTRA_v3": + return ASTRAv3(model_config) diff --git a/trainer/DCRNN_Trainer.py b/trainer/DCRNN_Trainer.py index a60eddb..1911248 100755 --- a/trainer/DCRNN_Trainer.py +++ b/trainer/DCRNN_Trainer.py @@ -104,14 +104,14 @@ class Trainer: loss = self.loss(output, label) # 检查output和label的shape是否一致 - if output.shape == label.shape: - print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") - import sys - sys.exit(0) - else: - print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") - import sys - sys.exit(1) + # if output.shape == label.shape: + # print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") + # import sys + # sys.exit(0) + # else: + # print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") + # import sys + # sys.exit(1) # 反归一化 d_output = self.scaler.inverse_transform(output)