Support the BeijingAirQuality dataset. Refactor the data download code; the pip requirements need to be updated accordingly.

czzhangheng 2025-11-20 20:19:17 +08:00
parent 9911caa3d8
commit a46edc79a5
10 changed files with 235 additions and 251 deletions

.vscode/launch.json (vendored, 8 changes)

@@ -52,6 +52,14 @@
            "console": "integratedTerminal",
            "args": "--config ./config/REPST/SolarEnergy.yaml"
        },
+        {
+            "name": "BeijingAirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BeijingAirQuality.yaml"
+        },
        {
            "name": "AEPSA-PEMSBAY",
            "type": "debugpy",

.vscode/settings.json (vendored, 12 changes)

@@ -1,5 +1,11 @@
{
-    "python-envs.defaultEnvManager": "ms-python.python:system",
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
-    "python-envs.defaultPackageManager": "ms-python.python:pip",
+    "python-envs.defaultPackageManager": "ms-python.python:conda",
-    "python-envs.pythonProjects": []
+    "python-envs.pythonProjects": [
+        {
+            "path": "data/SolarEnergy",
+            "envManager": "ms-python.python:system",
+            "packageManager": "ms-python.python:pip"
+        }
+    ]
}

config/REPST/BeijingAirQuality.yaml (new file)

@@ -0,0 +1,61 @@
basic:
dataset: "BeijingAirQuality"
mode : "train"
device : "cuda:1"
model: "REPST"
seed: 2023
data:
add_day_in_week: false
add_time_in_day: false
column_wise: false
days_per_week: 7
default_graph: true
horizon: 12
lag: 12
normalizer: std
num_nodes: 7
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 3
batch_size: 16
model:
pred_len: 12
seq_len: 12
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 3
output_dim: 3
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 3
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001


@@ -7,6 +7,11 @@ def load_st_dataset(config):
    # sample = config["data"]["sample"]
    # output B, N, D
    match dataset:
+        case "BeijingAirQuality":
+            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 36000, 7, 3
+            data = data.reshape(L, N, C)
        case "PEMS-BAY":
            data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
            with h5py.File(data_path, 'r') as f:
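For reference, a minimal standalone sketch of reading the memmapped file this new branch expects (the path, dtype and the 36000 × 7 × 3 shape come from the hunk above; the snippet itself is illustrative and not part of the commit):

import numpy as np

# BeijingAirQuality layout used above: 36000 time steps, 7 nodes, 3 channels
data = np.memmap("./data/BeijingAirQuality/data.dat", dtype=np.float32, mode='r')
data = data.reshape(36000, 7, 3)
print(data.dtype, data.shape)  # float32 (36000, 7, 3)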


@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
        return output

class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
        super(TokenEmbedding, self).__init__()
        padding = 1
        self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                   kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)

        for m in self.modules():
            if isinstance(m, nn.Conv1d):

@@ -37,17 +37,16 @@ class TokenEmbedding(nn.Module):
class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
        super(PatchEmbedding, self).__init__()
        # Patching
        self.patch_len = patch_len
        self.stride = stride
        self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        n_vars = x.shape[2]
        x = self.padding_patch_layer(x)
        x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)


@@ -19,6 +19,7 @@ class repst(nn.Module):
        self.gpt_layers = configs['gpt_layers']
        self.d_ff = configs['d_ff']
        self.gpt_path = configs['gpt_path']
+        self.output_dim = configs.get('output_dim', 1)

        self.word_choice = GumbelSoftmax(configs['word_num'])

@@ -31,7 +32,7 @@ class repst(nn.Module):
        self.head_nf = self.d_ff * self.patch_nums

        # Word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)

        # GPT-2 initialization
        self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)

@@ -41,7 +42,7 @@ class repst(nn.Module):
        self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
        self.vocab_size = self.word_embeddings.shape[0]
        self.mapping_layer = nn.Linear(self.vocab_size, 1)
-        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
+        self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm)

        self.out_mlp = nn.Sequential(
            nn.Linear(self.d_llm, 128),

@@ -62,7 +63,7 @@ class repst(nn.Module):
                torch.nn.init.zeros_(module.bias)

    def forward(self, x):
-        x = x[..., :1]
+        x = x[..., :self.output_dim]
        x_enc = rearrange(x, 'b t n c -> b n c t')
        enc_out, n_vars = self.patch_embedding(x_enc)
        self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

@@ -79,25 +80,3 @@ class repst(nn.Module):
        return outputs

-if __name__ == '__main__':
-    configs = {
-        'device': 'cuda:0',
-        'pred_len': 24,
-        'seq_len': 24,
-        'patch_len': 6,
-        'stride': 7,
-        'dropout': 0.2,
-        'gpt_layers': 9,
-        'd_ff': 128,
-        'gpt_path': './GPT-2',
-        'd_model': 64,
-        'n_heads': 1,
-        'input_dim': 1
-    }
-    model = repst(configs)
-    x = torch.randn(16, 24, 325, 1)
-    y = model(x)
-    print(y.shape)
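For reference, a usage sketch in the spirit of the removed smoke test, updated for the new output_dim path (illustrative only: the config values mirror config/REPST/BeijingAirQuality.yaml above, the input shape assumes batch 16, seq_len 12, 7 nodes and 3 channels, and GPT-2 weights are assumed to be available under ./GPT-2):

if __name__ == '__main__':
    configs = {
        'device': 'cuda:0', 'pred_len': 12, 'seq_len': 12, 'patch_len': 6,
        'stride': 7, 'dropout': 0.2, 'gpt_layers': 9, 'd_ff': 128,
        'gpt_path': './GPT-2', 'd_model': 64, 'n_heads': 1,
        'input_dim': 3, 'output_dim': 3, 'word_num': 1000,
    }
    model = repst(configs)
    x = torch.randn(16, 12, 7, 3)   # (batch, time, nodes, channels)
    y = model(x)
    print(y.shape)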


@@ -11,4 +11,5 @@ fastdtw
notebook
torchcde
einops
transformers
+py7zr

run.py (2 changes)

@@ -15,7 +15,7 @@ def main():
    args = init.init_device(args)
    init.init_seed(args["basic"]["seed"])
    # Load model
    model = init.init_model(args)
    # Load dataset


@@ -1,228 +1,114 @@
-import os
-import requests
-import zipfile
-import shutil
-import kagglehub  # assumes kagglehub is an available library
-from tqdm import tqdm
-
-# Dictionary describing the expected file layout
-
-def check_and_download_data():
-    """
-    Check the integrity of the data folder and download data according to which files are missing.
-    """
-    current_working_dir = os.getcwd()  # current working directory
-    data_dir = os.path.join(
-        current_working_dir, "data"
-    )  # assumes the data folder lives in the current working directory
-    expected_structure = {
-        "PEMS03": [
-            "PEMS03.csv",
-            "PEMS03.npz",
-            "PEMS03.txt",
-            "PEMS03_dtw_distance.npy",
-            "PEMS03_spatial_distance.npy",
-        ],
-        "PEMS04": [
-            "PEMS04.csv",
-            "PEMS04.npz",
-            "PEMS04_dtw_distance.npy",
-            "PEMS04_spatial_distance.npy",
-        ],
-        "PEMS07": [
-            "PEMS07.csv",
-            "PEMS07.npz",
-            "PEMS07_dtw_distance.npy",
-            "PEMS07_spatial_distance.npy",
-        ],
-        "PEMS08": [
-            "PEMS08.csv",
-            "PEMS08.npz",
-            "PEMS08_dtw_distance.npy",
-            "PEMS08_spatial_distance.npy",
-        ],
-        "PEMS-BAY": [
-            "adj_mx_bay.pkl",
-            "pems-bay-meta.h5",
-            "pems-bay.h5"
-        ],
-        "METR-LA": [
-            "METR-LA.h5"
-        ],
-        "SolarEnergy": [
-        ]
-    }
-
-    current_dir = os.getcwd()  # current working directory
-    missing_adj = False
-    missing_main_files = False
-
-    # Check whether the data folder exists
-    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
-        # print(f"Directory {data_dir} does not exist.")
-        print("Downloading all required data files...")
-        missing_adj = True
-        missing_main_files = True
-    else:
-        # Walk the expected file structure
-        for subfolder, expected_files in expected_structure.items():
-            subfolder_path = os.path.join(data_dir, subfolder)
-
-            # Check whether the subfolder exists
-            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
-                # print(f"Subfolder {subfolder} does not exist.")
-                missing_main_files = True
-                continue
-
-            # List the files actually present in the subfolder
-            actual_files = os.listdir(subfolder_path)
-
-            # Check for missing files
-            for expected_file in expected_files:
-                if expected_file not in actual_files:
-                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
-                    if (
-                        "_dtw_distance.npy" in expected_file
-                        or "_spatial_distance.npy" in expected_file
-                    ):
-                        missing_adj = True
-                    else:
-                        missing_main_files = True
-
-    # Trigger the download logic that matches the type of missing files
-    if missing_adj:
-        download_adj_data(current_dir)
-    if missing_main_files:
-        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
-        download_kaggle_data(current_dir, 'scchuy/pemsbay')
-        download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset")
-        download_kaggle_data(current_dir, "wangshaoqi/solar-energy")
-        rearrange_dir()
-    return True
-
-
-def download_adj_data(current_dir, max_retries=3):
-    """
-    Download and extract adj.zip, showing a progress bar.
-    Retry up to max_retries times if the download fails.
-    """
-    url = "http://code.zhang-heng.com/static/adj.zip"
-    retries = 0
-    while retries <= max_retries:
-        try:
-            print(f"Downloading the adjacency matrix files from {url} ...")
-            response = requests.get(url, stream=True)
-            if response.status_code == 200:
-                total_size = int(response.headers.get("content-length", 0))
-                block_size = 1024  # 1KB
-                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="Download progress")
-                zip_file_path = os.path.join(current_dir, "adj.zip")
-                with open(zip_file_path, "wb") as f:
-                    for data in response.iter_content(block_size):
-                        f.write(data)
-                        t.update(len(data))
-                t.close()
-                # print("Download finished, file saved to:", zip_file_path)
-                if os.path.exists(zip_file_path):
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        zip_ref.extractall(current_dir)
-                    # print("Dataset extracted to:", current_dir)
-                    os.remove(zip_file_path)  # delete the zip file
-                else:
-                    print("Downloaded zip file not found; skipping extraction.")
-                break  # download succeeded, exit the loop
-            else:
-                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
-        except Exception as e:
-            print(f"Error while downloading or extracting the dataset: {e}")
-            print("If the link is invalid, check the URL or retry later.")
-        retries += 1
-        if retries > max_retries:
-            raise Exception(
-                f"Download failed after the maximum number of retries ({max_retries}). Please check the link or your network connection."
-            )
-
-
-def download_kaggle_data(current_dir, kaggle_path):
-    """
-    Download a KaggleHub dataset and move it directly into the data folder of the current working directory.
-    Conflicting files are overwritten if the target folder already exists.
-    """
-    try:
-        print(f"Downloading the {kaggle_path} dataset...")
-        path = kagglehub.dataset_download(kaggle_path)
-        # print("Path to KaggleHub dataset files:", path)
-        if os.path.exists(path):
-            destination_path = os.path.join(current_dir, "data")
-            # Copy the contents directly into the data folder, overwriting conflicting files
-            shutil.copytree(path, destination_path, dirs_exist_ok=True)
-    except Exception as e:
-        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
-
-
-def rearrange_dir():
-    """
-    Merge the files in data/data into the parent directory and delete data/data.
-    """
-    data_dir = os.path.join(os.getcwd(), "data")
-    nested_data_dir = os.path.join(data_dir, "data")
-    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
-        for item in os.listdir(nested_data_dir):
-            source_path = os.path.join(nested_data_dir, item)
-            destination_path = os.path.join(data_dir, item)
-            if os.path.isdir(source_path):
-                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
-            else:
-                shutil.copy2(source_path, destination_path)
-        shutil.rmtree(nested_data_dir)
-
-    # Move files whose names contain "bay" into the PEMS-BAY folder
-    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
-    os.makedirs(pems_bay_dir, exist_ok=True)
-    for item in os.listdir(data_dir):
-        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(pems_bay_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # metr-la
-    metrla_dir = os.path.join(data_dir, "METR-LA")
-    os.makedirs(metrla_dir, exist_ok=True)
-    for item in os.listdir(data_dir):
-        if "metr" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(metrla_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # solar-energy
-    solar_src = os.path.join(data_dir, "solar-energy")
-    solar_sub = os.path.join(solar_src, "solar_AL.txt")
-    solar_csv = os.path.join(solar_src, "solar_AL.csv")
-    solar_dst_dir = os.path.join(data_dir, "SolarEnergy")
-    solar_dst_csv = os.path.join(solar_dst_dir, "SolarEnergy.csv")
-    if os.path.isdir(solar_sub): shutil.rmtree(solar_sub)
-    if os.path.isdir(solar_src): os.rename(solar_src, solar_dst_dir)
-    if os.path.isfile(solar_csv.replace(solar_src, solar_dst_dir)):
-        os.rename(solar_csv.replace(solar_src, solar_dst_dir), solar_dst_csv)
-
-
-# Main program
-if __name__ == "__main__":
-    check_and_download_data()
+import os, json, shutil, requests
+from urllib.parse import urlsplit
+from tqdm import tqdm
+import kagglehub
+import py7zr
+
+# ---------- 1. Load the structure JSON ----------
+def load_structure_json(path="utils/dataset.json"):
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)
+
+# ---------- 2. Integrity check ----------
+def detect_data_integrity(data_dir, expected, check_adj=False):
+    missing_adj, missing_main = False, False
+    if not os.path.isdir(data_dir): return True, True
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if not os.path.isdir(folder_path):
+            if check_adj:
+                missing_adj = True
+                continue
+            missing_main = True
+            continue
+        existing = set(os.listdir(folder_path))
+        for f in files:
+            if f not in existing:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
+                    missing_adj = True
+                elif not check_adj:
+                    missing_main = True
+    return missing_adj, missing_main
+
+# ---------- 3. Download and extract a 7z archive ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # Download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length",0))
+                with open(file_path,"wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # Extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt==max_retries: raise RuntimeError("Download or extraction failed")
+            print("Error, retrying...", e)
+
+# ---------- 4. Download Kaggle data ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)
+
+# ---------- 5. Rearrange the directory layout ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir,"data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src,dst = os.path.join(nested,item), os.path.join(data_dir,item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+    for kw,tgt in [("bay","PEMS-BAY"),("metr","METR-LA")]:
+        dst = os.path.join(data_dir,tgt); os.makedirs(dst,exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5",".pkl")):
+                shutil.move(os.path.join(data_dir,f), os.path.join(dst,f))
+    solar = os.path.join(data_dir,"solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir,"SolarEnergy"); os.makedirs(dst,exist_ok=True)
+        csv = os.path.join(solar,"solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst,"SolarEnergy.csv"))
+        shutil.rmtree(solar)

+# ---------- 6. Main flow ----------
+def check_and_download_data():
+    cwd = os.getcwd()
+    data_dir = os.path.join(cwd,"data")
+    expected = load_structure_json()
+    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
+    if missing_adj:
+        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
+    if not os.path.isdir(baq_folder):
+        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
+    if missing_main:
+        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
+            download_kaggle_data(cwd, ds)
+        rearrange_dir()
+    return True
+
+if __name__=="__main__":
+    check_and_download_data()

utils/dataset.json (new file, 39 lines)

@@ -0,0 +1,39 @@
{
"PEMS03": [
"PEMS03.csv",
"PEMS03.npz",
"PEMS03.txt",
"PEMS03_dtw_distance.npy",
"PEMS03_spatial_distance.npy"
],
"PEMS04": [
"PEMS04.csv",
"PEMS04.npz",
"PEMS04_dtw_distance.npy",
"PEMS04_spatial_distance.npy"
],
"PEMS07": [
"PEMS07.csv",
"PEMS07.npz",
"PEMS07_dtw_distance.npy",
"PEMS07_spatial_distance.npy"
],
"PEMS08": [
"PEMS08.csv",
"PEMS08.npz",
"PEMS08_dtw_distance.npy",
"PEMS08_spatial_distance.npy"
],
"PEMS-BAY": [
"adj_mx_bay.pkl",
"pems-bay-meta.h5",
"pems-bay.h5"
],
"METR-LA": [
"METR-LA.h5"
],
"SolarEnergy": [
"SolarEnergy.csv"
],
"BeijingAirQuality": ["data.dat", "desc.json"]
}