TrafficWheel/lib/Download_data.py

146 lines
5.8 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import requests
import zipfile
import shutil
import kagglehub # 假设 kagglehub 是一个可用的库
from tqdm import tqdm
# Expected layout of the ``data`` folder: sub-folder name -> files that must be
# present in it. Missing files whose names end with one of ``_ADJ_SUFFIXES``
# are fetched via download_adj_data; any other missing file is fetched via
# download_kaggle_data.
_EXPECTED_STRUCTURE = {
    "PEMS03": [
        "PEMS03.csv",
        "PEMS03.npz",
        "PEMS03.txt",
        "PEMS03_dtw_distance.npy",
        "PEMS03_spatial_distance.npy",
    ],
    "PEMS04": [
        "PEMS04.csv",
        "PEMS04.npz",
        "PEMS04_dtw_distance.npy",
        "PEMS04_spatial_distance.npy",
    ],
    "PEMS07": [
        "PEMS07.csv",
        "PEMS07.npz",
        "PEMS07_dtw_distance.npy",
        "PEMS07_spatial_distance.npy",
    ],
    "PEMS08": [
        "PEMS08.csv",
        "PEMS08.npz",
        "PEMS08_dtw_distance.npy",
        "PEMS08_spatial_distance.npy",
    ],
}

# Name suffixes of the distance/adjacency files handled by download_adj_data.
_ADJ_SUFFIXES = ("_dtw_distance.npy", "_spatial_distance.npy")


def check_and_download_data(base_dir=None):
    """
    Check that the ``data`` folder is complete and download whatever is missing.

    Decision rules:
      * ``<base_dir>/data`` does not exist (or is not a directory): both
        download_adj_data and download_kaggle_data are called.
      * ``data/get_adj.py`` is missing, or any ``*_dtw_distance.npy`` /
        ``*_spatial_distance.npy`` file is missing: download_adj_data is called.
      * Any other expected file, or a whole PEMS sub-folder, is missing:
        download_kaggle_data is called.

    Args:
        base_dir: Directory that contains (or should contain) ``data``.
            Defaults to the current working directory, which matches the
            original behaviour.

    Returns:
        Always ``True``. The value does not indicate whether a download
        succeeded.
    """
    if base_dir is None:
        base_dir = os.getcwd()
    data_dir = os.path.join(base_dir, "data")

    missing_adj = False
    missing_main_files = False

    # os.path.isdir is already False for a non-existent path, so a separate
    # exists() check is unnecessary.
    if not os.path.isdir(data_dir):
        print("正在下载所有必要的数据文件...")
        missing_adj = True
        missing_main_files = True
    else:
        # get_adj.py lives directly under data/ and is shipped with the adj archive.
        missing_adj = "get_adj.py" not in os.listdir(data_dir)
        for subfolder, expected_files in _EXPECTED_STRUCTURE.items():
            subfolder_path = os.path.join(data_dir, subfolder)
            if not os.path.isdir(subfolder_path):
                # A whole sub-folder is absent: it must come from the main dataset.
                missing_main_files = True
                continue
            present = set(os.listdir(subfolder_path))
            for expected_file in expected_files:
                if expected_file in present:
                    continue
                if expected_file.endswith(_ADJ_SUFFIXES):
                    missing_adj = True
                else:
                    missing_main_files = True

    # Dispatch to the matching downloader(s) for whatever is missing.
    if missing_adj:
        download_adj_data(base_dir)
    if missing_main_files:
        download_kaggle_data(base_dir)
    return True
def download_adj_data(current_dir, max_retries=3, timeout=30):
    """
    Download adj.zip, extract it into ``current_dir`` and delete the archive.

    A tqdm progress bar is shown while downloading. The download is attempted
    at most ``max_retries + 1`` times: one initial try plus ``max_retries``
    retries. A negative ``max_retries`` is treated as 0, so at least one
    attempt is always made.

    Args:
        current_dir: Directory the archive is written to and extracted into.
        max_retries: Number of retries after the first failed attempt.
        timeout: Seconds passed to ``requests.get`` so a stalled server cannot
            block the download forever.

    Raises:
        Exception: If every attempt fails. The last error, if any, is chained
            as ``__cause__``.
    """
    url = "https://code.zhang-heng.com/static/adj.zip"
    zip_file_path = os.path.join(current_dir, "adj.zip")
    chunk_size = 64 * 1024  # 64 KiB per write; 1 KiB chunks were needlessly small
    last_error = None

    for _attempt in range(max(max_retries, 0) + 1):
        try:
            print(f"正在从 {url} 下载邻接矩阵文件...")
            # The context managers release the streamed connection, the file
            # handle and the progress bar even if a write fails part-way.
            with requests.get(url, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"下载失败,状态码: {response.status_code}。请检查链接是否有效。")
                    continue
                total_size = int(response.headers.get('content-length', 0))
                # total=None lets tqdm show an open-ended bar when the size is unknown.
                with open(zip_file_path, 'wb') as f, tqdm(
                    total=total_size or None, unit='B', unit_scale=True, desc="下载进度"
                ) as progress:
                    for chunk in response.iter_content(chunk_size):
                        f.write(chunk)
                        progress.update(len(chunk))
            # The archive was just written above, so it is guaranteed to exist here.
            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                zip_ref.extractall(current_dir)
            os.remove(zip_file_path)
            return  # downloaded and extracted successfully
        except Exception as e:
            last_error = e
            print(f"下载或解压数据集时出错: {e}")
            print("如果链接无效请检查URL的合法性或稍后重试。")
            # Best-effort cleanup so a truncated/corrupt archive is not left behind.
            try:
                if os.path.exists(zip_file_path):
                    os.remove(zip_file_path)
            except OSError:
                pass

    raise Exception(
        f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。"
    ) from last_error
def download_kaggle_data(current_dir):
    """
    Download the "elmahy/pems-dataset" KaggleHub dataset and merge its ``data``
    folder into ``<current_dir>/data``.

    The merge uses ``shutil.copytree(..., dirs_exist_ok=True)``, so existing
    folders are reused and conflicting files are overwritten. This function is
    best-effort: errors are printed, not re-raised. When the downloaded dataset
    has no ``data`` folder, a message is printed instead of silently doing
    nothing.

    Args:
        current_dir: Directory whose ``data`` sub-folder receives the files.
    """
    try:
        print("正在下载 PEMS 数据集...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        if not os.path.exists(path):
            print("未找到 KaggleHub 数据集路径,跳过处理。")
            return
        data_folder_path = os.path.join(path, "data")
        if not os.path.exists(data_folder_path):
            print("未找到 data 文件夹,跳过合并操作。")
            return
        destination_path = os.path.join(current_dir, "data")
        # Merge into the existing data folder, overwriting conflicting files.
        shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
    except Exception as e:
        print(f"下载或处理 KaggleHub 数据集时出错: {e}")
# Script entry point: running this module directly checks the ``data`` folder
# under the current working directory and downloads anything that is missing.
if __name__ == "__main__":
    check_and_download_data()