import os
import requests
import zipfile
import shutil
import kagglehub  # assumes the kagglehub package is available
from tqdm import tqdm


# Expected file layout used to check data-folder integrity
def check_and_download_data():
    """
    Check the integrity of the data folder and download whichever
    data files are missing, based on the type of the missing files.
    """
    current_working_dir = os.getcwd()  # current working directory
    data_dir = os.path.join(
        current_working_dir, "data"
    )  # the data folder is expected under the current working directory
    expected_structure = {
        "PEMS03": [
            "PEMS03.csv",
            "PEMS03.npz",
            "PEMS03.txt",
            "PEMS03_dtw_distance.npy",
            "PEMS03_spatial_distance.npy",
        ],
        "PEMS04": [
            "PEMS04.csv",
            "PEMS04.npz",
            "PEMS04_dtw_distance.npy",
            "PEMS04_spatial_distance.npy",
        ],
        "PEMS07": [
            "PEMS07.csv",
            "PEMS07.npz",
            "PEMS07_dtw_distance.npy",
            "PEMS07_spatial_distance.npy",
        ],
        "PEMS08": [
            "PEMS08.csv",
            "PEMS08.npz",
            "PEMS08_dtw_distance.npy",
            "PEMS08_spatial_distance.npy",
        ],
        "PEMS-BAY": [
            "adj_mx_bay.pkl",
            "pems-bay-meta.h5",
            "pems-bay.h5",
        ],
    }

    current_dir = os.getcwd()  # current working directory
    missing_adj = False
    missing_main_files = False

    # Check whether the data folder exists
    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
        # print(f"Directory {data_dir} does not exist.")
        print("Downloading all required data files...")
        missing_adj = True
        missing_main_files = True
    else:
        # Check for get_adj.py at the top of the data folder
        if "get_adj.py" not in os.listdir(data_dir):
            # print("get_adj.py is missing from the data folder.")
            missing_adj = True

        # Walk through the expected file structure
        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)

            # Check whether the subfolder exists
            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
                # print(f"Subfolder {subfolder} does not exist.")
                missing_main_files = True
                continue

            # List the files actually present in the subfolder
            actual_files = os.listdir(subfolder_path)

            # Check for missing files
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
                    if (
                        "_dtw_distance.npy" in expected_file
                        or "_spatial_distance.npy" in expected_file
                    ):
                        missing_adj = True
                    else:
                        missing_main_files = True

    # Run the download logic that matches the type of missing files
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir, "elmahy/pems-dataset")
        download_kaggle_data(current_dir, "scchuy/pemsbay")
        rearrange_dir()
    return True


def download_adj_data(current_dir, max_retries=3):
    """
    Download and extract adj.zip, showing a progress bar.
    If the download fails, retry up to max_retries times.
    """
    url = "http://code.zhang-heng.com/static/adj.zip"
    retries = 0
    while retries <= max_retries:
        try:
            print(f"Downloading adjacency-matrix files from {url} ...")
            response = requests.get(url, stream=True)
            if response.status_code == 200:
                total_size = int(response.headers.get("content-length", 0))
                block_size = 1024  # 1 KB
                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="Downloading")
                zip_file_path = os.path.join(current_dir, "adj.zip")
                with open(zip_file_path, "wb") as f:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))
                t.close()
                # print("Download complete, file saved to:", zip_file_path)

                if os.path.exists(zip_file_path):
                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
                        zip_ref.extractall(current_dir)
                    # print("Dataset extracted to:", current_dir)
                    os.remove(zip_file_path)  # remove the zip archive after extraction
                else:
                    print("Downloaded zip file not found; skipping extraction.")
                break  # download succeeded, exit the retry loop
            else:
                print(
                    f"Download failed with status code {response.status_code}. "
                    "Please check that the link is valid."
                )
        except Exception as e:
            print(f"Error while downloading or extracting the dataset: {e}")
            print("If the link is invalid, verify the URL or try again later.")
        retries += 1
        if retries > max_retries:
            raise Exception(
                f"Download failed after the maximum number of retries ({max_retries}). "
                "Please check the link or your network connection."
            )


def download_kaggle_data(current_dir, kaggle_path):
    """
    Download a KaggleHub dataset and move its contents directly into the
    data folder under the current working directory. If the destination
    folder already exists, conflicting files are overwritten.
    """
    try:
        print(f"Downloading the {kaggle_path} dataset...")
        path = kagglehub.dataset_download(kaggle_path)
        # print("Path to KaggleHub dataset files:", path)

        if os.path.exists(path):
            destination_path = os.path.join(current_dir, "data")
            # copytree with dirs_exist_ok=True places the contents directly
            # inside the data folder, overwriting conflicting files
            shutil.copytree(path, destination_path, dirs_exist_ok=True)
    except Exception as e:
        print(f"Error while downloading or processing the KaggleHub dataset: {e}")


def rearrange_dir():
    """
    Merge the contents of data/data into the parent data directory and
    remove the nested data/data directory.
    """
    data_dir = os.path.join(os.getcwd(), "data")
    nested_data_dir = os.path.join(data_dir, "data")

    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
        for item in os.listdir(nested_data_dir):
            source_path = os.path.join(nested_data_dir, item)
            destination_path = os.path.join(data_dir, item)
            if os.path.isdir(source_path):
                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
            else:
                shutil.copy2(source_path, destination_path)
        shutil.rmtree(nested_data_dir)
        # print(f"Merged {nested_data_dir} into {data_dir} and removed the nested directory.")

    # Move files whose names contain "bay" into the PEMS-BAY folder
    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
    os.makedirs(pems_bay_dir, exist_ok=True)
    for item in os.listdir(data_dir):
        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
            source_path = os.path.join(data_dir, item)
            destination_path = os.path.join(pems_bay_dir, item)
            shutil.move(source_path, destination_path)
    # print(f"Moved files containing 'bay' into {pems_bay_dir}.")


# Main program
if __name__ == "__main__":
    check_and_download_data()
    # rearrange_dir()