212 lines
7.4 KiB
Python
Executable File
212 lines
7.4 KiB
Python
Executable File
import os
|
||
import requests
|
||
import zipfile
|
||
import shutil
|
||
import kagglehub # 假设 kagglehub 是一个可用的库
|
||
from tqdm import tqdm
|
||
|
||
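# The expected per-dataset file layout used by the integrity check is declared together with check_and_download_data() below.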
|
||
|
||
|
||
# Expected contents of ./data, keyed by dataset sub-folder. A missing file whose
# name contains one of _ADJ_MARKERS triggers the adjacency download
# (download_adj_data); any other missing file triggers the Kaggle downloads.
EXPECTED_STRUCTURE = {
    "PEMS03": [
        "PEMS03.csv",
        "PEMS03.npz",
        "PEMS03.txt",
        "PEMS03_dtw_distance.npy",
        "PEMS03_spatial_distance.npy",
    ],
    "PEMS04": [
        "PEMS04.csv",
        "PEMS04.npz",
        "PEMS04_dtw_distance.npy",
        "PEMS04_spatial_distance.npy",
    ],
    "PEMS07": [
        "PEMS07.csv",
        "PEMS07.npz",
        "PEMS07_dtw_distance.npy",
        "PEMS07_spatial_distance.npy",
    ],
    "PEMS08": [
        "PEMS08.csv",
        "PEMS08.npz",
        "PEMS08_dtw_distance.npy",
        "PEMS08_spatial_distance.npy",
    ],
    "PEMS-BAY": [
        "adj_mx_bay.pkl",
        "pems-bay-meta.h5",
        "pems-bay.h5",
    ],
}

# Substrings that mark a file as part of the adjacency data set.
_ADJ_MARKERS = ("_dtw_distance.npy", "_spatial_distance.npy")


def _scan_data_dir(data_dir):
    """Classify what is missing from an existing ``data_dir``.

    Returns:
        A ``(missing_adj, missing_main_files)`` tuple of booleans.
        ``missing_adj`` is True when ``get_adj.py`` is absent from the root
        or any adjacency file (see ``_ADJ_MARKERS``) is absent.
        ``missing_main_files`` is True when any other expected file is absent.
    """
    missing_adj = "get_adj.py" not in os.listdir(data_dir)
    missing_main_files = False

    for subfolder, expected_files in EXPECTED_STRUCTURE.items():
        subfolder_path = os.path.join(data_dir, subfolder)
        # A missing sub-folder means *all* of its files are missing, including
        # its adjacency files, so classify each file instead of only flagging
        # the main data (otherwise the adjacency download would be skipped).
        present = (
            set(os.listdir(subfolder_path))
            if os.path.isdir(subfolder_path)
            else set()
        )
        for expected_file in expected_files:
            if expected_file in present:
                continue
            if any(marker in expected_file for marker in _ADJ_MARKERS):
                missing_adj = True
            else:
                missing_main_files = True

    return missing_adj, missing_main_files


def check_and_download_data():
    """Check ``<cwd>/data`` for completeness and download whatever is missing.

    Missing adjacency data (``get_adj.py`` or ``*_dtw_distance.npy`` /
    ``*_spatial_distance.npy``) triggers ``download_adj_data``. Any other
    missing file or sub-folder triggers both Kaggle dataset downloads. When
    anything was downloaded, ``rearrange_dir`` then tidies the layout. An
    already complete data directory is left untouched.

    Returns:
        Always True once the check (and any downloads) has finished.
        ``download_kaggle_data`` only prints its errors, so True does not
        guarantee that every file is now present.

    Raises:
        Whatever ``download_adj_data`` raises after exhausting its retries.
    """
    current_dir = os.getcwd()
    data_dir = os.path.join(current_dir, "data")

    if os.path.isdir(data_dir):
        missing_adj, missing_main_files = _scan_data_dir(data_dir)
    else:
        print("正在下载所有必要的数据文件...")
        missing_adj = missing_main_files = True

    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
        download_kaggle_data(current_dir, 'scchuy/pemsbay')

    if missing_adj or missing_main_files:
        # Downloads may land in data/data or leave PEMS-BAY files at the
        # data root; only reorganize when something was just fetched.
        rearrange_dir()

    return True
|
||
|
||
|
||
def download_adj_data(
    current_dir,
    max_retries=3,
    url="http://code.zhang-heng.com/static/adj.zip",
    timeout=30,
):
    """Download ``adj.zip`` with a progress bar and extract it into ``current_dir``.

    One initial attempt is made, plus up to ``max_retries`` retries. A
    non-200 status, a network error or a broken archive all count as a failed
    attempt. The temporary ``adj.zip`` is always removed afterwards, so no
    partial archive is left behind.

    Args:
        current_dir: Directory where ``adj.zip`` is written and extracted.
        max_retries: Number of retries after the first failed attempt.
        url: Location of the adjacency archive.
        timeout: Seconds ``requests`` waits for the connection and for data
            between received chunks. Without it a stalled server blocks forever.

    Raises:
        RuntimeError: if every attempt fails. The last error, if any, is
            chained as the cause.
    """
    zip_file_path = os.path.join(current_dir, "adj.zip")
    block_size = 1024  # 1KB
    last_error = None

    for _attempt in range(max_retries + 1):
        try:
            print(f"正在从 {url} 下载邻接矩阵文件...")
            # Context managers release the streamed connection and the
            # progress bar even when an error interrupts the download.
            with requests.get(url, stream=True, timeout=timeout) as response:
                if response.status_code != 200:
                    print(f"下载失败,状态码: {response.status_code}。请检查链接是否有效。")
                    continue

                total_size = int(response.headers.get("content-length", 0))
                with open(zip_file_path, "wb") as f, tqdm(
                    total=total_size, unit="B", unit_scale=True, desc="下载进度"
                ) as t:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))

            with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
                zip_ref.extractall(current_dir)
            return  # download and extraction succeeded
        except Exception as e:
            last_error = e
            print(f"下载或解压数据集时出错: {e}")
            print("如果链接无效,请检查URL的合法性或稍后重试。")
        finally:
            # Remove the archive after extraction, and also after a failure,
            # so that a partial or corrupt file does not linger.
            if os.path.exists(zip_file_path):
                os.remove(zip_file_path)

    raise RuntimeError(
        f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。"
    ) from last_error
|
||
|
||
|
||
def download_kaggle_data(current_dir, kaggle_path):
    """Fetch a KaggleHub dataset and merge it into ``<current_dir>/data``.

    The dataset is fetched with ``kagglehub.dataset_download``. If the
    returned path exists, its contents are copied into the ``data`` folder,
    overwriting files with the same name. Any error is printed rather than
    raised, so this is a best-effort step.

    Args:
        current_dir: Directory that contains (or will contain) ``data``.
        kaggle_path: KaggleHub dataset handle, e.g. ``"owner/dataset"``.
    """
    try:
        print(f"正在下载 {kaggle_path} 数据集...")
        downloaded = kagglehub.dataset_download(kaggle_path)

        # Nothing to merge if the download location is absent.
        if not os.path.exists(downloaded):
            return

        target = os.path.join(current_dir, "data")
        shutil.copytree(downloaded, target, dirs_exist_ok=True)
    except Exception as err:
        print(f"下载或处理 KaggleHub 数据集时出错: {err}")
|
||
|
||
|
||
|
||
def rearrange_dir(data_dir=None):
    """Flatten a nested ``data/data`` folder and gather PEMS-BAY files.

    Steps:
      1. If ``<data_dir>/data`` is a directory, its contents are copied into
         ``data_dir``. Sub-directories are merged and same-named files
         overwritten. The nested directory is then deleted.
      2. ``<data_dir>/PEMS-BAY`` is created if needed. Every regular
         ``.pkl`` / ``.h5`` file directly under ``data_dir`` whose name
         contains "bay" (case-insensitive) is moved into it.

    Args:
        data_dir: Directory to tidy. Defaults to ``<cwd>/data``, which
            preserves the original behavior for existing callers.
    """
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), "data")
    nested_data_dir = os.path.join(data_dir, "data")

    if os.path.isdir(nested_data_dir):
        for item in os.listdir(nested_data_dir):
            source_path = os.path.join(nested_data_dir, item)
            destination_path = os.path.join(data_dir, item)
            if os.path.isdir(source_path):
                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
            else:
                shutil.copy2(source_path, destination_path)
        shutil.rmtree(nested_data_dir)

    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
    os.makedirs(pems_bay_dir, exist_ok=True)

    # os.listdir returns a snapshot list, so moving entries while looping is safe.
    for item in os.listdir(data_dir):
        source_path = os.path.join(data_dir, item)
        if (
            "bay" in item.lower()
            and item.endswith((".pkl", ".h5"))
            and os.path.isfile(source_path)
        ):
            shutil.move(source_path, os.path.join(pems_bay_dir, item))
|
||
|
||
|
||
# Script entry point: running this file directly verifies ./data and
# downloads whatever is missing.
if __name__ == "__main__":
    check_and_download_data()
|