From 475a4788cd0a69eee2dea42963bc8330b9a73788 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 23 Nov 2025 19:04:50 +0800 Subject: [PATCH] =?UTF-8?q?=E5=85=BC=E5=AE=B9BJTaxi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 15 +++++---- config/REPST/BJTaxi-Inflow.yaml | 60 +++++++++++++++++++++++++++++++++ dataloader/data_selector.py | 17 ++++++++++ utils/Download_data.py | 22 +++++------- 4 files changed, 93 insertions(+), 21 deletions(-) create mode 100755 config/REPST/BJTaxi-Inflow.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 771325c..4f992f0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,13 +4,6 @@ // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - { - "name": "Python Debugger: Current File", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal" - }, { "name": "STID_PEMS-BAY", "type": "debugpy", @@ -35,6 +28,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/PEMSD8.yaml" }, + { + "name": "REPST-BJTaxi-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" + }, { "name": "REPST-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml new file mode 100755 index 0000000..37577a2 --- /dev/null +++ b/config/REPST/BJTaxi-Inflow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "BJTaxi-InFlow" + mode : "train" + device : "cuda:0" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 45987d6..224b6fc 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -55,6 +55,10 @@ def load_st_dataset(config): case "SD": data_path = os.path.join("./data/SD/data.npz") data = np.load(data_path)["data"][:, :, 0].astype(np.float32) + case "BJTaxi-InFlow": + data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) + case "BJTaxi-OutFlow": + data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) case _: raise ValueError(f"Unsupported dataset: {dataset}") @@ -65,3 +69,16 @@ def load_st_dataset(config): print("加载 %s 数据集中... " % dataset) # return data[::sample] return data + +def read_BeijingTaxi(): + files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", + "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] + all_data = [] + for file in files: + data_path = os.path.join(f"./data/BeijingTaxi/{file}") + data = np.load(data_path) + all_data.append(data) + all_data = np.concatenate(all_data, axis=0) + time_num = all_data.shape[0] + all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) + return all_data \ No newline at end of file diff --git a/utils/Download_data.py b/utils/Download_data.py index e0ce6ce..539a804 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -4,12 +4,8 @@ from tqdm import tqdm import kagglehub import py7zr -# ---------- 1. 加载结构 JSON ---------- -def load_structure_json(path="utils/dataset.json"): - with open(path, "r", encoding="utf-8") as f: - return json.load(f) -# ---------- 2. 检测完整性 ---------- +# ---------- 1. 检测完整性 ---------- def detect_data_integrity(data_dir, expected): missing_list = [] if not os.path.isdir(data_dir): @@ -53,7 +49,7 @@ def detect_data_integrity(data_dir, expected): # print(f"缺失数据集:{missing_list}") return missing_list -# ---------- 3. 下载 7z 并解压 ---------- +# ---------- 2. 下载 7z 并解压 ---------- def download_and_extract(url, dst_dir, max_retries=3): os.makedirs(dst_dir, exist_ok=True) filename = os.path.basename(urlsplit(url).path) or "download.7z" @@ -77,7 +73,7 @@ def download_and_extract(url, dst_dir, max_retries=3): if attempt==max_retries: raise RuntimeError("下载或解压失败") print("错误,重试中...", e) -# ---------- 4. 下载 Kaggle 数据 ---------- +# ---------- 3. 下载 Kaggle 数据 ---------- def download_kaggle_data(base_dir, dataset): try: print(f"Downloading kaggle dataset : {dataset}") @@ -86,7 +82,7 @@ def download_kaggle_data(base_dir, dataset): except Exception as e: print("Kaggle 下载失败:", dataset, e) -# ---------- 5. 下载 GitHub 数据 ---------- +# ---------- 4. 下载 GitHub 数据 ---------- def download_github_data(file_path, save_dir): if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -136,13 +132,13 @@ def rearrange_dir(): # ---------- 6. 主流程 ---------- def check_and_download_data(): + # 加载结构文件,检测缺失数据集 cwd = os.getcwd() data_dir = os.path.join(cwd,"data") - file_tree = load_structure_json() - - # 执行一次检测,获取所有缺失项 + with open("utils/dataset.json", "r", encoding="utf-8") as f: + file_tree = json.load(f) missing_list = detect_data_integrity(data_dir, file_tree) - print(f"缺失数据集:{missing_list}") + # print(f"缺失数据集:{missing_list}") # 检查并下载adj数据 if "adj" in missing_list: @@ -167,7 +163,6 @@ def check_and_download_data(): missing_list = detect_data_integrity(data_dir, file_tree) # 检查并下载pems, bay, metr-la, solar-energy数据 - # 定义数据集名称到Kaggle数据集的映射 kaggle_map = { "PEMS03": "elmahy/pems-dataset", "PEMS04": "elmahy/pems-dataset", @@ -178,7 +173,6 @@ def check_and_download_data(): "SolarEnergy": "wangshaoqi/solar-energy" } - # 检查是否有需要从Kaggle下载的数据集 # 先对kaggle下载地址进行去重,避免重复下载相同的数据集 downloaded_kaggle_datasets = set()