From 2685d049d79f733ebe83c6449333217a189b302f Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 18:51:47 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9pems-bay?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 9 +++++ config/REPST/PEMS-BAY.yaml | 58 +++++++++++++++++++++++++++++ dataloader/PeMSDdataloader.py | 4 ++ lib/Download_data.py | 69 +++++++++++++++++++++++++++-------- 4 files changed, 124 insertions(+), 16 deletions(-) create mode 100755 config/REPST/PEMS-BAY.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 3f71813..752af1d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,7 @@ // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { "name": "EXP_PEMSD8", "type": "debugpy", @@ -19,6 +20,14 @@ "program": "run.py", "console": "integratedTerminal", "args": "--config ./config/REPST/PEMSD8.yaml" + }, + { + "name": "REPST-PEMSBAY", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/PEMS-BAY.yaml" } ] } \ No newline at end of file diff --git a/config/REPST/PEMS-BAY.yaml b/config/REPST/PEMS-BAY.yaml new file mode 100755 index 0000000..1333cb8 --- /dev/null +++ b/config/REPST/PEMS-BAY.yaml @@ -0,0 +1,58 @@ +basic: + dataset: "PEMS-BAY" + mode : "train" + device : "cuda:0" + model: "REPST" + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 325 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + seed: 12 + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/PeMSDdataloader.py b/dataloader/PeMSDdataloader.py index f6764a7..f30ef4f 100755 --- a/dataloader/PeMSDdataloader.py +++ b/dataloader/PeMSDdataloader.py @@ -118,6 +118,10 @@ def load_st_dataset(config): sample = config["data"]["sample"] # output B, N, D match dataset: + case "PEMS-BAY": + data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") + with h5py.File(data_path, 'r') as f: + data = f['speed']['block0_values'][:] case "PEMSD3": data_path = os.path.join("./data/PEMS03/PEMS03.npz") data = np.load(data_path)["data"][ diff --git a/lib/Download_data.py b/lib/Download_data.py index 6b2c648..0dae1fd 100755 --- a/lib/Download_data.py +++ b/lib/Download_data.py @@ -43,6 +43,11 @@ def check_and_download_data(): "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy", ], + "PEMS-BAY": [ + "adj_mx_bay.pkl", + "pems-bay-meta.h5", + "pems-bay.h5" + ] } current_dir = os.getcwd() # 获取当前工作目录 @@ -90,7 +95,12 @@ def check_and_download_data(): if missing_adj: download_adj_data(current_dir) if missing_main_files: - download_kaggle_data(current_dir) + download_kaggle_data(current_dir, 'elmahy/pems-dataset') + download_kaggle_data(current_dir, 'scchuy/pemsbay') + + rearrange_dir() + + return True @@ -143,32 +153,59 @@ def download_adj_data(current_dir, max_retries=3): ) -def download_kaggle_data(current_dir): +def download_kaggle_data(current_dir, kaggle_path): """ - 下载 KaggleHub 数据集,并将 data 文件夹合并到当前工作目录。 + 下载 KaggleHub 数据集,并将数据直接移动到当前工作目录的 data 文件夹。 如果目标文件夹已存在,会覆盖冲突的文件。 """ try: - print("正在下载 PEMS 数据集...") - path = kagglehub.dataset_download("elmahy/pems-dataset") + print(f"正在下载 {kaggle_path} 数据集...") + path = kagglehub.dataset_download(kaggle_path) # print("Path to KaggleHub dataset files:", path) if os.path.exists(path): - data_folder_path = os.path.join(path, "data") - if os.path.exists(data_folder_path): - destination_path = os.path.join(current_dir, "data") - - # 使用 shutil.copytree 合并文件夹,覆盖冲突的文件 - shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True) - # print(f"data 文件夹已合并到: {destination_path}") - # else: - # print("未找到 data 文件夹,跳过合并操作。") - # else: - # print("未找到 KaggleHub 数据集路径,跳过处理。") + destination_path = os.path.join(current_dir, "data") + # 使用 shutil.copytree 将文件夹内容直接放在 data 文件夹下,覆盖冲突的文件 + shutil.copytree(path, destination_path, dirs_exist_ok=True) except Exception as e: print(f"下载或处理 KaggleHub 数据集时出错: {e}") + +def rearrange_dir(): + """ + 将 data/data 中的文件合并到上级目录,并删除 data/data 目录。 + """ + data_dir = os.path.join(os.getcwd(), "data") + nested_data_dir = os.path.join(data_dir, "data") + + if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir): + for item in os.listdir(nested_data_dir): + source_path = os.path.join(nested_data_dir, item) + destination_path = os.path.join(data_dir, item) + + if os.path.isdir(source_path): + shutil.copytree(source_path, destination_path, dirs_exist_ok=True) + else: + shutil.copy2(source_path, destination_path) + + shutil.rmtree(nested_data_dir) + # print(f"已合并 {nested_data_dir} 到 {data_dir},并删除嵌套目录。") + + # 将带有 "bay" 的文件移动到 PEMS-BAY 文件夹 + pems_bay_dir = os.path.join(data_dir, "PEMS-BAY") + os.makedirs(pems_bay_dir, exist_ok=True) + + for item in os.listdir(data_dir): + if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")): + source_path = os.path.join(data_dir, item) + destination_path = os.path.join(pems_bay_dir, item) + shutil.move(source_path, destination_path) + + # print(f"已将带有 'bay' 的文件移动到 {pems_bay_dir}。") + + # 主程序 if __name__ == "__main__": check_and_download_data() + # rearrange_dir()