更名ASTRA
This commit is contained in:
parent
aed1e53f0f
commit
9c76975056
|
|
@ -153,7 +153,7 @@
|
||||||
"args": "--config ./config/REPST/AirQuality.yaml"
|
"args": "--config ./config/REPST/AirQuality.yaml"
|
||||||
},
|
},
|
||||||
|
|
||||||
// AEPSA 模型组
|
// ASTRA 模型组
|
||||||
{
|
{
|
||||||
"name": "AEPSA: PEMS-BAY",
|
"name": "AEPSA: PEMS-BAY",
|
||||||
"type": "debugpy",
|
"type": "debugpy",
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: AirQuality
|
dataset: AirQuality
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: BJTaxi-InFlow
|
dataset: BJTaxi-InFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: BJTaxi-InFlow
|
dataset: BJTaxi-InFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: BJTaxi-OutFlow
|
dataset: BJTaxi-OutFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: BJTaxi-OutFlow
|
dataset: BJTaxi-OutFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: METR-LA
|
dataset: METR-LA
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: NYCBike-InFlow
|
dataset: NYCBike-InFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: NYCBike-OutFlow
|
dataset: NYCBike-OutFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: NYCBike-InFlow
|
dataset: NYCBike-InFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: NYCBike-OutFlow
|
dataset: NYCBike-OutFlow
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: PEMS-BAY
|
dataset: PEMS-BAY
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ basic:
|
||||||
dataset: SolarEnergy
|
dataset: SolarEnergy
|
||||||
device: cuda:0
|
device: cuda:0
|
||||||
mode: train
|
mode: train
|
||||||
model: AEPSA
|
model: ASTRA
|
||||||
seed: 2023
|
seed: 2023
|
||||||
|
|
||||||
data:
|
data:
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from model.AEPSA.normalizer import GumbelSoftmax
|
from model.ASTRA.normalizer import GumbelSoftmax
|
||||||
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer
|
from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
class DynamicGraphEnhancer(nn.Module):
|
class DynamicGraphEnhancer(nn.Module):
|
||||||
|
|
@ -147,10 +147,10 @@ class GraphEnhancedEncoder(nn.Module):
|
||||||
|
|
||||||
return torch.stack(enhanced_features, dim=0)
|
return torch.stack(enhanced_features, dim=0)
|
||||||
|
|
||||||
class AEPSA(nn.Module):
|
class ASTRA(nn.Module):
|
||||||
|
|
||||||
def __init__(self, configs):
|
def __init__(self, configs):
|
||||||
super(AEPSA, self).__init__()
|
super(ASTRA, self).__init__()
|
||||||
self.device = configs['device']
|
self.device = configs['device']
|
||||||
self.pred_len = configs['pred_len']
|
self.pred_len = configs['pred_len']
|
||||||
self.seq_len = configs['seq_len']
|
self.seq_len = configs['seq_len']
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from model.AEPSA.normalizer import GumbelSoftmax
|
from model.ASTRA.normalizer import GumbelSoftmax
|
||||||
from model.AEPSA.reprogramming import ReprogrammingLayer
|
from model.ASTRA.reprogramming import ReprogrammingLayer
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# 基于动态图增强的时空序列预测模型实现
|
# 基于动态图增强的时空序列预测模型实现
|
||||||
|
|
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
|
||||||
|
|
||||||
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征
|
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征
|
||||||
|
|
||||||
class AEPSA(nn.Module):
|
class ASTRA(nn.Module):
|
||||||
"""自适应特征投影时空自注意力模型"""
|
"""自适应特征投影时空自注意力模型"""
|
||||||
def __init__(self, configs):
|
def __init__(self, configs):
|
||||||
super(AEPSA, self).__init__()
|
super(ASTRA, self).__init__()
|
||||||
self.device = configs['device'] # 运行设备
|
self.device = configs['device'] # 运行设备
|
||||||
self.pred_len = configs['pred_len'] # 预测序列长度
|
self.pred_len = configs['pred_len'] # 预测序列长度
|
||||||
self.seq_len = configs['seq_len'] # 输入序列长度
|
self.seq_len = configs['seq_len'] # 输入序列长度
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from model.AEPSA.normalizer import GumbelSoftmax
|
from model.ASTRA.normalizer import GumbelSoftmax
|
||||||
from model.AEPSA.reprogramming import ReprogrammingLayer
|
from model.ASTRA.reprogramming import ReprogrammingLayer
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
# 基于动态图增强的时空序列预测模型实现
|
# 基于动态图增强的时空序列预测模型实现
|
||||||
|
|
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
|
||||||
|
|
||||||
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征
|
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)],每个节点在每个k阶下的切比雪夫特征
|
||||||
|
|
||||||
class AEPSA(nn.Module):
|
class ASTRA(nn.Module):
|
||||||
"""自适应特征投影时空自注意力模型"""
|
"""自适应特征投影时空自注意力模型"""
|
||||||
def __init__(self, configs):
|
def __init__(self, configs):
|
||||||
super(AEPSA, self).__init__()
|
super(ASTRA, self).__init__()
|
||||||
self.device = configs['device'] # 运行设备
|
self.device = configs['device'] # 运行设备
|
||||||
self.pred_len = configs['pred_len'] # 预测序列长度
|
self.pred_len = configs['pred_len'] # 预测序列长度
|
||||||
self.seq_len = configs['seq_len'] # 输入序列长度
|
self.seq_len = configs['seq_len'] # 输入序列长度
|
||||||
|
|
|
||||||
|
|
@ -23,9 +23,9 @@ from model.ST_SSL.ST_SSL import STSSLModel
|
||||||
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
from model.STGNRDE.Make_model import make_model as make_nrde_model
|
||||||
from model.STAWnet.STAWnet import STAWnet
|
from model.STAWnet.STAWnet import STAWnet
|
||||||
from model.REPST.repst import repst as REPST
|
from model.REPST.repst import repst as REPST
|
||||||
from model.AEPSA.aepsa import AEPSA as AEPSA
|
from model.ASTRA.astra import ASTRA as ASTRA
|
||||||
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2
|
from model.ASTRA.astrav2 import ASTRA as ASTRAv2
|
||||||
from model.AEPSA.aepsav3 import AEPSA as AEPSAv3
|
from model.ASTRA.astrav3 import ASTRA as ASTRAv3
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,9 +83,9 @@ def model_selector(config):
|
||||||
return STAWnet(model_config)
|
return STAWnet(model_config)
|
||||||
case "REPST":
|
case "REPST":
|
||||||
return REPST(model_config)
|
return REPST(model_config)
|
||||||
case "AEPSA":
|
case "ASTRA":
|
||||||
return AEPSA(model_config)
|
return ASTRA(model_config)
|
||||||
case "AEPSA_v2":
|
case "ASTRA_v2":
|
||||||
return AEPSAv2(model_config)
|
return ASTRAv2(model_config)
|
||||||
case "AEPSA_v3":
|
case "ASTRA_v3":
|
||||||
return AEPSAv3(model_config)
|
return ASTRAv3(model_config)
|
||||||
|
|
|
||||||
|
|
@ -104,14 +104,14 @@ class Trainer:
|
||||||
loss = self.loss(output, label)
|
loss = self.loss(output, label)
|
||||||
|
|
||||||
# 检查output和label的shape是否一致
|
# 检查output和label的shape是否一致
|
||||||
if output.shape == label.shape:
|
# if output.shape == label.shape:
|
||||||
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
# print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
|
||||||
import sys
|
# import sys
|
||||||
sys.exit(0)
|
# sys.exit(0)
|
||||||
else:
|
# else:
|
||||||
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
# print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
|
||||||
import sys
|
# import sys
|
||||||
sys.exit(1)
|
# sys.exit(1)
|
||||||
|
|
||||||
# 反归一化
|
# 反归一化
|
||||||
d_output = self.scaler.inverse_transform(output)
|
d_output = self.scaler.inverse_transform(output)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue