Update the AEPSA framework; files still pending modification

czzhangheng 2025-11-11 21:29:52 +08:00
parent 6657743afe
commit a3fb2d5ec6
7 changed files with 320 additions and 0 deletions

8
.vscode/launch.json vendored

@@ -35,6 +35,14 @@
            "program": "run.py",
            "console": "integratedTerminal",
            "args": "--config ./config/REPST/PEMS-BAY.yaml"
        },
        {
            "name": "AEPSA-PEMSBAY",
            "type": "debugpy",
            "request": "launch",
            "program": "run.py",
            "console": "integratedTerminal",
            "args": "--config ./config/AEPSA/PEMS-BAY.yaml"
        }
    ]
}

59
config/AEPSA/PEMS-BAY.yaml Executable file

@@ -0,0 +1,59 @@
basic:
  dataset: "PEMS-BAY"
  mode: "train"
  device: "cuda:0"
  model: "AEPSA"

data:
  add_day_in_week: true
  add_time_in_day: true
  column_wise: false
  days_per_week: 7
  default_graph: true
  horizon: 24
  lag: 24
  normalizer: std
  num_nodes: 325
  steps_per_day: 288
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
  sample: 1
  input_dim: 1
  batch_size: 16

model:
  pred_len: 24
  seq_len: 24
  patch_len: 6
  stride: 7
  dropout: 0.2
  gpt_layers: 9
  d_ff: 128
  gpt_path: ./GPT-2
  d_model: 64
  n_heads: 1
  input_dim: 1

train:
  batch_size: 16
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  lr_init: 0.003
  max_grad_norm: 5
  real_value: true
  seed: 12
  weight_decay: 0
  debug: false
  output_dim: 1
  log_step: 100
  plot: false
  mae_thresh: None
  mape_thresh: 0.001
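For reference, a minimal sketch of reading this nested config with a standard PyYAML loader; run.py's actual loading code is not part of this commit and may differ:

import yaml

with open('./config/AEPSA/PEMS-BAY.yaml') as f:
    cfg = yaml.safe_load(f)

print(cfg['basic']['model'])                              # AEPSA
print(cfg['model']['patch_len'], cfg['model']['stride'])  # 6 7

Note that an unquoted `None` (as in mae_thresh) is loaded by safe_load as the string "None" rather than a Python None, so the training code presumably checks for that value explicitly.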

58
config/AEPSA/PEMSD8.yaml Executable file

@@ -0,0 +1,58 @@
basic:
  dataset: "PEMSD8"
  mode: "train"
  device: "cuda:0"
  model: "AEPSA"

data:
  add_day_in_week: true
  add_time_in_day: true
  column_wise: false
  days_per_week: 7
  default_graph: true
  horizon: 12
  lag: 12
  normalizer: std
  num_nodes: 170
  steps_per_day: 288
  test_ratio: 0.2
  tod: false
  val_ratio: 0.2
  sample: 1
  input_dim: 1
  batch_size: 64

model:
  pred_len: 12
  seq_len: 12
  patch_len: 6
  stride: 7
  dropout: 0.2
  gpt_layers: 9
  d_ff: 128
  gpt_path: ./GPT-2
  d_model: 64
  n_heads: 1

train:
  batch_size: 64
  early_stop: true
  early_stop_patience: 15
  epochs: 100
  grad_norm: false
  loss_func: mae
  lr_decay: true
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  lr_init: 0.003
  max_grad_norm: 5
  real_value: true
  seed: 12
  weight_decay: 0
  debug: false
  output_dim: 1
  log_step: 2000
  plot: false
  mae_thresh: None
  mape_thresh: 0.001

16
model/AEPSA/normalizer.py Normal file

@@ -0,0 +1,16 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def gumbel_softmax(logits, tau=1, k=1000, hard=True):
    y_soft = F.gumbel_softmax(logits, tau, hard)
    if hard:
        # build a hard 0/1 mask
        _, indices = y_soft.topk(k, dim=0)  # keep the top-k positions
        y_hard = torch.zeros_like(logits)
        y_hard.scatter_(0, indices, 1)
        return torch.squeeze(y_hard, dim=-1)
    return torch.squeeze(y_soft, dim=-1)
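A quick sketch of how this mask is consumed downstream. The shapes are assumptions taken from repst.py below, where the logits are the transposed weight of an nn.Linear(vocab_size, 1) and the mask selects rows of GPT-2's word-embedding matrix:

import torch
from model.REPST.normalizer import gumbel_softmax  # same import path repst.py uses

vocab_size, d_llm, k = 50257, 768, 1000
word_embeddings = torch.randn(vocab_size, d_llm)   # stand-in for GPT-2's input embeddings
logits = torch.randn(vocab_size, 1)                # stand-in for mapping_layer.weight.T

mask = gumbel_softmax(logits, tau=1, k=k, hard=True)  # 0/1 vector of length vocab_size
prototypes = word_embeddings[mask == 1]               # exactly k rows kept as text prototypes
print(mask.shape, prototypes.shape)                   # torch.Size([50257]) torch.Size([1000, 768])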


@@ -0,0 +1,96 @@
import torch
import torch.nn as nn
from torch import Tensor
from math import sqrt


class ReplicationPad1d(nn.Module):
    def __init__(self, padding) -> None:
        super(ReplicationPad1d, self).__init__()
        self.padding = padding

    def forward(self, input: Tensor) -> Tensor:
        replicate_padding = input[:, :, :, -1].unsqueeze(-1).repeat(1, 1, 1, self.padding[-1])
        output = torch.cat([input, replicate_padding], dim=-1)
        return output


class TokenEmbedding(nn.Module):
    def __init__(self, c_in, d_model, patch_num, input_dim):
        super(TokenEmbedding, self).__init__()
        padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
        for m in self.modules():
            if isinstance(m, nn.Conv1d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        b, n, m, pn, pl = x.shape  # batch, node, feature, patch_num, patch_len
        # 768,64,25
        x = self.tokenConv(x.reshape(b * n, pl, m * pn))  # batch*node, patch_len, feature*patch_num
        x = self.confusion_layer(x)
        return x.reshape(b, n, -1)


class PatchEmbedding(nn.Module):
    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        n_vars = x.shape[2]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        x_value_embed = self.value_embedding(x)
        return self.dropout(x_value_embed), n_vars


class ReprogrammingLayer(nn.Module):
    def __init__(self, d_model, n_heads, d_keys, d_llm, attention_dropout=0.1):
        super(ReprogrammingLayer, self).__init__()
        d_keys = d_keys or (d_model // n_heads)
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.value_projection = nn.Linear(d_llm, d_keys * n_heads)
        self.out_projection = nn.Linear(d_keys * n_heads, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        B, L, _ = target_embedding.shape
        S, _ = source_embedding.shape
        H = self.n_heads

        target_embedding = self.query_projection(target_embedding).view(B, L, H, -1)
        source_embedding = self.key_projection(source_embedding).view(S, H, -1)
        value_embedding = self.value_projection(value_embedding).view(S, H, -1)

        out = self.reprogramming(target_embedding, source_embedding, value_embedding)
        out = out.reshape(B, L, -1)
        return self.out_projection(out)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        B, L, H, E = target_embedding.shape
        scale = 1. / sqrt(E)
        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
        return reprogramming_embedding
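To illustrate the tensor shapes these modules trade in, a small sketch using the PEMS-BAY config values as assumptions (batch 16, 325 nodes, input_dim 1, seq_len 24, patch_len 6, stride 7, d_model 64, d_llm 768); the import path follows the one used in repst.py below:

import torch
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer

b, n, c, t = 16, 325, 1, 24
patch_len, stride, d_model, d_llm = 6, 7, 64, 768
patch_num = int((t - patch_len) / stride + 2)      # 4, same formula as repst.__init__

patch_embed = PatchEmbedding(d_model, patch_len, stride, dropout=0.2,
                             patch_num=patch_num, input_dim=c)
reprogram = ReprogrammingLayer(d_model, n_heads=1, d_keys=None, d_llm=d_llm)

x = torch.randn(b, n, c, t)                        # (batch, node, feature, time)
enc_out, n_vars = patch_embed(x)                   # (16, 325, 64), n_vars = 1
prototypes = torch.randn(1000, d_llm)              # stand-in for the selected word embeddings
out = reprogram(enc_out, prototypes, prototypes)   # (16, 325, 768), ready for GPT-2's inputs_embeds
print(enc_out.shape, out.shape)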

80
model/AEPSA/repst.py Normal file

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
from einops import rearrange

from model.REPST.normalizer import gumbel_softmax
from model.REPST.reprogramming import PatchEmbedding, ReprogrammingLayer


class repst(nn.Module):
    def __init__(self, configs):
        super(repst, self).__init__()
        self.device = configs['device']
        self.pred_len = configs['pred_len']
        self.seq_len = configs['seq_len']
        self.patch_len = configs['patch_len']
        self.input_dim = configs['input_dim']
        self.stride = configs['stride']
        self.dropout = configs['dropout']
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
        self.d_model = configs['d_model']
        self.n_heads = configs['n_heads']
        self.d_keys = None
        self.d_llm = 768

        self.patch_nums = int((self.seq_len - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        # patch/token embedding
        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)

        # GPT-2 initialization
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
        self.gpts.h = self.gpts.h[:self.gpt_layers]
        self.gpts.apply(self.reset_parameters)

        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]

        self.mapping_layer = nn.Linear(self.vocab_size, 1)
        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)

        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),
            nn.ReLU(),
            nn.Linear(128, self.pred_len)
        )

        # freeze GPT-2 except for the positional embeddings (wpe)
        for i, (name, param) in enumerate(self.gpts.named_parameters()):
            if 'wpe' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

    def reset_parameters(self, module):
        if hasattr(module, 'weight') and module.weight is not None:
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if hasattr(module, 'bias') and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        x = x[..., :1]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, n_vars = self.patch_embedding(x_enc)

        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
        masks = gumbel_softmax(self.mapping_layer.weight.data.permute(1, 0))
        source_embeddings = self.word_embeddings[masks == 1]

        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
        dec_out = self.out_mlp(enc_out)

        outputs = dec_out.unsqueeze(dim=-1)
        outputs = outputs.repeat(1, 1, 1, n_vars)
        outputs = outputs.permute(0, 2, 1, 3)
        return outputs
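As a sanity check on the patching arithmetic, the patch_nums formula above agrees with the number of windows Tensor.unfold produces after ReplicationPad1d appends stride extra steps; a minimal sketch for both shipped configs:

def n_patches(seq_len, patch_len=6, stride=7):
    padded = seq_len + stride                  # ReplicationPad1d((0, stride)) appends `stride` steps
    return (padded - patch_len) // stride + 1  # windows from unfold(size=patch_len, step=stride)

for name, seq_len in [("PEMS-BAY", 24), ("PEMSD8", 12)]:
    formula = int((seq_len - 6) / 7 + 2)       # patch_nums as computed in repst.__init__
    print(name, formula, n_patches(seq_len))   # PEMS-BAY: 4 4, PEMSD8: 2 2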


@@ -23,6 +23,7 @@ from model.ST_SSL.ST_SSL import STSSLModel
from model.STGNRDE.Make_model import make_model as make_nrde_model
from model.STAWnet.STAWnet import STAWnet
from model.REPST.repst import repst as REPST
from model.AEPSA.repst import repst as AEPSA
def model_selector(config):
@@ -79,3 +80,5 @@ def model_selector(config):
            return STAWnet(model_config)
        case "REPST":
            return REPST(model_config)
        case "AEPSA":
            return AEPSA(model_config)
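For completeness, a hedged sketch of instantiating the new branch directly. The keys below are the ones repst.__init__ reads, filled with the PEMS-BAY values; the real model_selector may assemble model_config differently, and GPT-2 weights must be present at ./GPT-2:

from model.AEPSA.repst import repst as AEPSA

model_config = {
    'device': 'cuda:0', 'pred_len': 24, 'seq_len': 24, 'patch_len': 6,
    'stride': 7, 'dropout': 0.2, 'gpt_layers': 9, 'd_ff': 128,
    'gpt_path': './GPT-2', 'd_model': 64, 'n_heads': 1, 'input_dim': 1,
}
model = AEPSA(model_config)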