更名ASTRA

This commit is contained in:
czzhangheng 2025-12-09 14:07:38 +08:00
parent aed1e53f0f
commit 9c76975056
18 changed files with 42 additions and 42 deletions

2
.vscode/launch.json vendored
View File

@ -153,7 +153,7 @@
"args": "--config ./config/REPST/AirQuality.yaml"
},
// AEPSA
// ASTRA
{
"name": "AEPSA: PEMS-BAY",
"type": "debugpy",

View File

@ -2,7 +2,7 @@ basic:
dataset: AirQuality
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-InFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-InFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-OutFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-OutFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: METR-LA
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-InFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-OutFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-InFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-OutFlow
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: PEMS-BAY
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,7 +2,7 @@ basic:
dataset: SolarEnergy
device: cuda:0
mode: train
model: AEPSA
model: ASTRA
seed: 2023
data:

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer
from model.ASTRA.normalizer import GumbelSoftmax
from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
import torch.nn.functional as F
class DynamicGraphEnhancer(nn.Module):
@ -147,10 +147,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0)
class AEPSA(nn.Module):
class ASTRA(nn.Module):
def __init__(self, configs):
super(AEPSA, self).__init__()
super(ASTRA, self).__init__()
self.device = configs['device']
self.pred_len = configs['pred_len']
self.seq_len = configs['seq_len']

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer
from model.ASTRA.normalizer import GumbelSoftmax
from model.ASTRA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F
# 基于动态图增强的时空序列预测模型实现
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征
class AEPSA(nn.Module):
class ASTRA(nn.Module):
"""自适应特征投影时空自注意力模型"""
def __init__(self, configs):
super(AEPSA, self).__init__()
super(ASTRA, self).__init__()
self.device = configs['device'] # 运行设备
self.pred_len = configs['pred_len'] # 预测序列长度
self.seq_len = configs['seq_len'] # 输入序列长度

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer
from model.ASTRA.normalizer import GumbelSoftmax
from model.ASTRA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F
# 基于动态图增强的时空序列预测模型实现
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征
class AEPSA(nn.Module):
class ASTRA(nn.Module):
"""自适应特征投影时空自注意力模型"""
def __init__(self, configs):
super(AEPSA, self).__init__()
super(ASTRA, self).__init__()
self.device = configs['device'] # 运行设备
self.pred_len = configs['pred_len'] # 预测序列长度
self.seq_len = configs['seq_len'] # 输入序列长度

View File

@ -23,9 +23,9 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST
from model.AEPSA.aepsa import AEPSA as AEPSA
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2
from model.AEPSA.aepsav3 import AEPSA as AEPSAv3
from model.ASTRA.astra import ASTRA as ASTRA
from model.ASTRA.astrav2 import ASTRA as ASTRAv2
from model.ASTRA.astrav3 import ASTRA as ASTRAv3
@ -83,9 +83,9 @@ def model_selector(config):
return STAWnet(model_config)
case "REPST":
return REPST(model_config)
case "AEPSA":
return AEPSA(model_config)
case "AEPSA_v2":
return AEPSAv2(model_config)
case "AEPSA_v3":
return AEPSAv3(model_config)
case "ASTRA":
return ASTRA(model_config)
case "ASTRA_v2":
return ASTRAv2(model_config)
case "ASTRA_v3":
return ASTRAv3(model_config)

View File

@ -104,14 +104,14 @@ class Trainer:
loss = self.loss(output, label)
# 检查output和label的shape是否一致
if output.shape == label.shape:
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
import sys
sys.exit(0)
else:
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
import sys
sys.exit(1)
# if output.shape == label.shape:
# print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
# import sys
# sys.exit(0)
# else:
# print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
# import sys
# sys.exit(1)
# 反归一化
d_output = self.scaler.inverse_transform(output)