REPST #3

Merged
czzhangheng merged 42 commits from REPST into main 2025-12-20 16:03:22 +08:00
18 changed files with 42 additions and 42 deletions
Showing only changes of commit 9c76975056 - Show all commits

2
.vscode/launch.json vendored
View File

@ -153,7 +153,7 @@
"args": "--config ./config/REPST/AirQuality.yaml" "args": "--config ./config/REPST/AirQuality.yaml"
}, },
// AEPSA // ASTRA
{ {
"name": "AEPSA: PEMS-BAY", "name": "AEPSA: PEMS-BAY",
"type": "debugpy", "type": "debugpy",

View File

@ -2,7 +2,7 @@ basic:
dataset: AirQuality dataset: AirQuality
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-InFlow dataset: BJTaxi-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-InFlow dataset: BJTaxi-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-OutFlow dataset: BJTaxi-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: BJTaxi-OutFlow dataset: BJTaxi-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: METR-LA dataset: METR-LA
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-InFlow dataset: NYCBike-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-OutFlow dataset: NYCBike-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-InFlow dataset: NYCBike-InFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: NYCBike-OutFlow dataset: NYCBike-OutFlow
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: PEMS-BAY dataset: PEMS-BAY
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,7 +2,7 @@ basic:
dataset: SolarEnergy dataset: SolarEnergy
device: cuda:0 device: cuda:0
mode: train mode: train
model: AEPSA model: ASTRA
seed: 2023 seed: 2023
data: data:

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax from model.ASTRA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import PatchEmbedding, ReprogrammingLayer from model.ASTRA.reprogramming import PatchEmbedding, ReprogrammingLayer
import torch.nn.functional as F import torch.nn.functional as F
class DynamicGraphEnhancer(nn.Module): class DynamicGraphEnhancer(nn.Module):
@ -147,10 +147,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0) return torch.stack(enhanced_features, dim=0)
class AEPSA(nn.Module): class ASTRA(nn.Module):
def __init__(self, configs): def __init__(self, configs):
super(AEPSA, self).__init__() super(ASTRA, self).__init__()
self.device = configs['device'] self.device = configs['device']
self.pred_len = configs['pred_len'] self.pred_len = configs['pred_len']
self.seq_len = configs['seq_len'] self.seq_len = configs['seq_len']

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax from model.ASTRA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer from model.ASTRA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F import torch.nn.functional as F
# 基于动态图增强的时空序列预测模型实现 # 基于动态图增强的时空序列预测模型实现
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征 return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征
class AEPSA(nn.Module): class ASTRA(nn.Module):
"""自适应特征投影时空自注意力模型""" """自适应特征投影时空自注意力模型"""
def __init__(self, configs): def __init__(self, configs):
super(AEPSA, self).__init__() super(ASTRA, self).__init__()
self.device = configs['device'] # 运行设备 self.device = configs['device'] # 运行设备
self.pred_len = configs['pred_len'] # 预测序列长度 self.pred_len = configs['pred_len'] # 预测序列长度
self.seq_len = configs['seq_len'] # 输入序列长度 self.seq_len = configs['seq_len'] # 输入序列长度

View File

@ -2,8 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange from einops import rearrange
from model.AEPSA.normalizer import GumbelSoftmax from model.ASTRA.normalizer import GumbelSoftmax
from model.AEPSA.reprogramming import ReprogrammingLayer from model.ASTRA.reprogramming import ReprogrammingLayer
import torch.nn.functional as F import torch.nn.functional as F
# 基于动态图增强的时空序列预测模型实现 # 基于动态图增强的时空序列预测模型实现
@ -113,10 +113,10 @@ class GraphEnhancedEncoder(nn.Module):
return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征 return torch.stack(enhanced_features, dim=0) # 堆叠返回[B,N,hidden_dim*(K+1)]每个节点在每个k阶下的切比雪夫特征
class AEPSA(nn.Module): class ASTRA(nn.Module):
"""自适应特征投影时空自注意力模型""" """自适应特征投影时空自注意力模型"""
def __init__(self, configs): def __init__(self, configs):
super(AEPSA, self).__init__() super(ASTRA, self).__init__()
self.device = configs['device'] # 运行设备 self.device = configs['device'] # 运行设备
self.pred_len = configs['pred_len'] # 预测序列长度 self.pred_len = configs['pred_len'] # 预测序列长度
self.seq_len = configs['seq_len'] # 输入序列长度 self.seq_len = configs['seq_len'] # 输入序列长度

View File

@ -23,9 +23,9 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST from model.REPST.repst import repst as REPST
from model.AEPSA.aepsa import AEPSA as AEPSA from model.ASTRA.astra import ASTRA as ASTRA
from model.AEPSA.aepsav2 import AEPSA as AEPSAv2 from model.ASTRA.astrav2 import ASTRA as ASTRAv2
from model.AEPSA.aepsav3 import AEPSA as AEPSAv3 from model.ASTRA.astrav3 import ASTRA as ASTRAv3
@ -83,9 +83,9 @@ def model_selector(config):
return STAWnet(model_config) return STAWnet(model_config)
case "REPST": case "REPST":
return REPST(model_config) return REPST(model_config)
case "AEPSA": case "ASTRA":
return AEPSA(model_config) return ASTRA(model_config)
case "AEPSA_v2": case "ASTRA_v2":
return AEPSAv2(model_config) return ASTRAv2(model_config)
case "AEPSA_v3": case "ASTRA_v3":
return AEPSAv3(model_config) return ASTRAv3(model_config)

View File

@ -104,14 +104,14 @@ class Trainer:
loss = self.loss(output, label) loss = self.loss(output, label)
# 检查output和label的shape是否一致 # 检查output和label的shape是否一致
if output.shape == label.shape: # if output.shape == label.shape:
print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}") # print(f"✓ Test passed: output shape {output.shape} matches label shape {label.shape}")
import sys # import sys
sys.exit(0) # sys.exit(0)
else: # else:
print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}") # print(f"✗ Test failed: output shape {output.shape} does not match label shape {label.shape}")
import sys # import sys
sys.exit(1) # sys.exit(1)
# 反归一化 # 反归一化
d_output = self.scaler.inverse_transform(output) d_output = self.scaler.inverse_transform(output)