diff --git a/.vscode/launch.json b/.vscode/launch.json
index 96f2427..771325c 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,7 +4,14 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
    "configurations": [
-{
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        },
+        {
             "name": "STID_PEMS-BAY",
             "type": "debugpy",
             "request": "launch",
diff --git a/utils/Download_data.py b/utils/Download_data.py
index beec5de..e0ce6ce 100755
--- a/utils/Download_data.py
+++ b/utils/Download_data.py
@@ -10,25 +10,48 @@ def load_structure_json(path="utils/dataset.json"):
         return json.load(f)
 
 # ---------- 2. Check data integrity ----------
-def detect_data_integrity(data_dir, expected, check_adj=False):
-    missing_adj, missing_main = False, False
-    if not os.path.isdir(data_dir): return True, True
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # If the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # Mark adj as missing as well
+        missing_list.append("adj")
+        return missing_list
+
+    # Check the adj-related files (distance matrix files)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # Check the main dataset files
     for folder, files in expected.items():
         folder_path = os.path.join(data_dir, folder)
         if not os.path.isdir(folder_path):
-            if check_adj:
-                missing_adj = True
-                continue
-            missing_main = True
+            missing_list.append(folder)
             continue
+
         existing = set(os.listdir(folder_path))
+        has_missing_file = False
+
         for f in files:
-            if f not in existing:
-                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
-                    missing_adj = True
-                elif not check_adj:
-                    missing_main = True
-    return missing_adj, missing_main
+            # Skip distance matrix files; they were already checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list
 
 # ---------- 3. Download 7z archives and extract ----------
 def download_and_extract(url, dst_dir, max_retries=3):
@@ -57,11 +80,34 @@ def download_and_extract(url, dst_dir, max_retries=3):
 # ---------- 4. Download Kaggle data ----------
 def download_kaggle_data(base_dir, dataset):
     try:
+        print(f"Downloading Kaggle dataset: {dataset}")
         path = kagglehub.dataset_download(dataset)
         shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
     except Exception as e:
         print("Kaggle download failed:", dataset, e)
 
+# ---------- 5. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
-# ---------- 5. Rearrange directories ----------
+# ---------- 6. Rearrange directories ----------
 def rearrange_dir():
     data_dir = os.path.join(os.getcwd(), "data")
@@ -92,21 +138,57 @@ def rearrange_dir():
 def check_and_download_data():
     cwd = os.getcwd()
     data_dir = os.path.join(cwd,"data")
-    expected = load_structure_json()
+    file_tree = load_structure_json()
 
-    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
-    if missing_adj:
+    # Run the detection once to collect every missing item
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    print(f"Missing datasets: {missing_list}")
+
+    # Check for and download the adj data
+    if "adj" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # Remove adj from the missing list after downloading
+        missing_list.remove("adj")
 
-    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
-    baq_folder2 = os.path.join(data_dir,"AirQuality")
-    if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2):
+    # Check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # Check for and download the TaxiBJ data (reported under its dataset.json folder name, BeijingTaxi)
+    if "BeijingTaxi" in missing_list:
+        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_folder)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
 
-    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
-    if missing_main:
-        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
-            download_kaggle_data(cwd, ds)
+    # Check for and download the PEMS, PEMS-BAY, METR-LA, and SolarEnergy data
+    # Map dataset names to their Kaggle dataset identifiers
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+
+    # Check whether any datasets need to be downloaded from Kaggle
+    # Deduplicate Kaggle identifiers first to avoid downloading the same dataset twice
+    downloaded_kaggle_datasets = set()
+
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # Record the Kaggle dataset as downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # Refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)
 
     rearrange_dir()
     return True
diff --git a/utils/dataset.json b/utils/dataset.json
index a778eff..55f46c3 100644
--- a/utils/dataset.json
+++ b/utils/dataset.json
@@ -36,5 +36,6 @@
         "SolarEnergy.csv"
     ],
     "BeijingAirQuality": ["data.dat", "desc.json"],
-    "AirQuality": ["data.dat"]
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
 }