import os
import shutil
import zipfile

import requests
import kagglehub  # third-party helper for downloading Kaggle datasets
from tqdm import tqdm


def check_and_download_data():
    """Verify the integrity of the ``data`` folder and download missing pieces.

    Two independent groups of files are tracked:

    * adjacency files (``get_adj.py`` plus ``*_dtw_distance.npy`` /
      ``*_spatial_distance.npy``) -> fetched by :func:`download_adj_data`;
    * the main PEMS dataset files (``.csv`` / ``.npz`` / ``.txt``) ->
      fetched by :func:`download_kaggle_data`.

    Returns:
        bool: always ``True`` (kept for backward compatibility with callers).
    """
    current_dir = os.getcwd()
    # The data folder is expected directly under the current working directory.
    data_dir = os.path.join(current_dir, "data")

    # Expected layout: sub-folder name -> files it must contain.
    expected_structure = {
        "PEMS03": [
            "PEMS03.csv",
            "PEMS03.npz",
            "PEMS03.txt",
            "PEMS03_dtw_distance.npy",
            "PEMS03_spatial_distance.npy",
        ],
        "PEMS04": [
            "PEMS04.csv",
            "PEMS04.npz",
            "PEMS04_dtw_distance.npy",
            "PEMS04_spatial_distance.npy",
        ],
        "PEMS07": [
            "PEMS07.csv",
            "PEMS07.npz",
            "PEMS07_dtw_distance.npy",
            "PEMS07_spatial_distance.npy",
        ],
        "PEMS08": [
            "PEMS08.csv",
            "PEMS08.npz",
            "PEMS08_dtw_distance.npy",
            "PEMS08_spatial_distance.npy",
        ],
    }

    missing_adj = False
    missing_main_files = False

    # os.path.isdir already returns False for nonexistent paths, so a
    # separate os.path.exists check is unnecessary.
    if not os.path.isdir(data_dir):
        print("正在下载所有必要的数据文件...")
        missing_adj = True
        missing_main_files = True
    else:
        # get_adj.py must live at the root of the data folder.
        if "get_adj.py" not in os.listdir(data_dir):
            missing_adj = True

        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)
            if not os.path.isdir(subfolder_path):
                missing_main_files = True
                continue

            actual_files = os.listdir(subfolder_path)
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # Distance matrices come from the adjacency archive;
                    # every other file comes from the Kaggle dataset.
                    if (
                        "_dtw_distance.npy" in expected_file
                        or "_spatial_distance.npy" in expected_file
                    ):
                        missing_adj = True
                    else:
                        missing_main_files = True

    # Trigger the appropriate download(s) based on what is missing.
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir)
    return True


def download_adj_data(current_dir, max_retries=3):
    """Download and extract ``adj.zip`` into *current_dir*, with a progress bar.

    Retries up to *max_retries* additional times on failure; raises
    ``Exception`` once the retry budget is exhausted.

    Args:
        current_dir: directory where the archive is saved and extracted.
        max_retries: maximum number of retries after the first attempt.

    Raises:
        Exception: if every attempt fails.
    """
    url = "http://code.zhang-heng.com/static/adj.zip"
    retries = 0
    while retries <= max_retries:
        try:
            print(f"正在从 {url} 下载邻接矩阵文件...")
            # timeout prevents a stalled connection from hanging forever;
            # the context manager guarantees the connection is released.
            with requests.get(url, stream=True, timeout=30) as response:
                if response.status_code == 200:
                    total_size = int(response.headers.get("content-length", 0))
                    block_size = 1024  # 1 KB chunks
                    zip_file_path = os.path.join(current_dir, "adj.zip")
                    # Both the progress bar and the output file are closed
                    # even if the download is interrupted mid-stream.
                    with tqdm(
                        total=total_size, unit="B", unit_scale=True, desc="下载进度"
                    ) as t, open(zip_file_path, "wb") as f:
                        for chunk in response.iter_content(block_size):
                            if chunk:  # skip keep-alive chunks
                                f.write(chunk)
                                t.update(len(chunk))

                    if os.path.exists(zip_file_path):
                        with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
                            zip_ref.extractall(current_dir)
                        os.remove(zip_file_path)  # archive no longer needed
                    else:
                        print("未找到下载的zip文件,跳过解压。")
                    break  # success: leave the retry loop
                else:
                    print(
                        f"下载失败,状态码: {response.status_code}。请检查链接是否有效。"
                    )
        except Exception as e:
            print(f"下载或解压数据集时出错: {e}")
            print("如果链接无效,请检查URL的合法性或稍后重试。")
        retries += 1
        if retries > max_retries:
            raise Exception(
                f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。"
            )


def download_kaggle_data(current_dir):
    """Download the PEMS dataset via KaggleHub and merge it into *current_dir*.

    The dataset's ``data`` folder is copied into ``current_dir/data``,
    overwriting files that already exist. Errors are reported but not raised
    (best-effort behavior, matching the original implementation).

    Args:
        current_dir: directory into which the ``data`` folder is merged.
    """
    try:
        print("正在下载 PEMS 数据集...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        if os.path.exists(path):
            data_folder_path = os.path.join(path, "data")
            if os.path.exists(data_folder_path):
                destination_path = os.path.join(current_dir, "data")
                # dirs_exist_ok merges into an existing tree, overwriting
                # conflicting files.
                shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
    except Exception as e:
        print(f"下载或处理 KaggleHub 数据集时出错: {e}")


if __name__ == "__main__":
    check_and_download_data()