# TrafficWheel/utils/Download_data.py

import os, json, shutil, requests
from urllib.parse import urlsplit
from tqdm import tqdm
import kagglehub
import py7zr
# ---------- 1. Load the expected dataset structure ----------
def load_structure_json(path="utils/dataset.json"):
    """Read the expected dataset directory layout from a JSON file.

    Returns the parsed mapping of {folder_name: [expected_file, ...]}.
    """
    with open(path, "r", encoding="utf-8") as handle:
        structure = json.load(handle)
    return structure
# ---------- 2. Check dataset completeness ----------
def detect_data_integrity(data_dir, expected, check_adj=False):
    """Scan *data_dir* against the *expected* {folder: [files]} layout.

    Returns a (missing_adj, missing_main) pair of booleans: whether any
    adjacency-matrix files (*_dtw_distance.npy / *_spatial_distance.npy)
    or any main dataset files are absent.  With ``check_adj=True`` a
    missing folder is charged to the adjacency side and missing main
    files are ignored.
    """
    if not os.path.isdir(data_dir):
        # No data directory at all: report everything as missing.
        return True, True
    adj_suffixes = ("_dtw_distance.npy", "_spatial_distance.npy")
    adj_gone = False
    main_gone = False
    for folder, wanted in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            if check_adj:
                adj_gone = True
            else:
                main_gone = True
            continue
        present = set(os.listdir(folder_path))
        for name in wanted:
            if name in present:
                continue
            if name.endswith(adj_suffixes):
                adj_gone = True
            elif not check_adj:
                main_gone = True
    return adj_gone, main_gone
# ---------- 3. Download a 7z archive and extract it ----------
def download_and_extract(url, dst_dir, max_retries=3):
    """Download a .7z archive from *url* into *dst_dir*, extract it there,
    then delete the archive.

    Retries up to *max_retries* times; on final failure raises RuntimeError
    chained to the underlying exception so the root cause stays visible.
    A partial/corrupt download is removed before each retry.
    """
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)
    for attempt in range(1, max_retries + 1):
        try:
            # Stream the download so large archives are never fully buffered in memory.
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with open(file_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        bar.update(len(chunk))
            # Extract the 7z archive, then remove it to reclaim disk space.
            with py7zr.SevenZipFile(file_path, mode='r') as archive:
                archive.extractall(path=dst_dir)
            os.remove(file_path)
            return
        except Exception as e:
            # Drop any partial/corrupt download before retrying or giving up.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except OSError:
                    pass
            if attempt == max_retries:
                # Chain the original exception instead of swallowing it.
                raise RuntimeError("下载或解压失败") from e
            print("错误,重试中...", e)
# ---------- 4. Download a Kaggle dataset ----------
def download_kaggle_data(base_dir, dataset):
    """Fetch *dataset* via kagglehub and merge it into <base_dir>/data.

    Best-effort: any failure is reported on stdout rather than raised, so
    one bad dataset does not abort the remaining downloads.
    """
    destination = os.path.join(base_dir, "data")
    try:
        cache_path = kagglehub.dataset_download(dataset)
        shutil.copytree(cache_path, destination, dirs_exist_ok=True)
    except Exception as exc:
        print("Kaggle 下载失败:", dataset, exc)
# ---------- 5. Normalize the data directory layout ----------
def rearrange_dir():
    """Normalize the layout of ./data after the Kaggle downloads.

    Flattens an accidentally nested data/data directory, groups loose
    PEMS-BAY / METR-LA files into their own folders, and renames the
    solar-energy CSV to SolarEnergy/SolarEnergy.csv.
    """
    data_dir = os.path.join(os.getcwd(), "data")

    # 1. Flatten a nested data/data directory into data/.
    nested = os.path.join(data_dir, "data")
    if os.path.isdir(nested):
        for entry in os.listdir(nested):
            source = os.path.join(nested, entry)
            target = os.path.join(data_dir, entry)
            if os.path.isdir(source):
                # Merge into any directory that already exists.
                shutil.copytree(source, target, dirs_exist_ok=True)
            else:
                shutil.copy2(source, target)
        shutil.rmtree(nested)

    # 2. Collect loose .h5/.pkl files into per-dataset folders.
    for keyword, folder in (("bay", "PEMS-BAY"), ("metr", "METR-LA")):
        target_dir = os.path.join(data_dir, folder)
        os.makedirs(target_dir, exist_ok=True)
        for entry in os.listdir(data_dir):
            if keyword in entry.lower() and entry.endswith((".h5", ".pkl")):
                shutil.move(os.path.join(data_dir, entry), os.path.join(target_dir, entry))

    # 3. Move the solar-energy CSV to its canonical name and drop the old folder.
    solar_dir = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar_dir):
        target_dir = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(target_dir, exist_ok=True)
        csv_path = os.path.join(solar_dir, "solar_AL.csv")
        if os.path.isfile(csv_path):
            shutil.move(csv_path, os.path.join(target_dir, "SolarEnergy.csv"))
        shutil.rmtree(solar_dir)
# ---------- 6. Main workflow ----------
def check_and_download_data():
    """Ensure every expected dataset exists under ./data, downloading
    whatever is missing.

    Checks adjacency matrices, the BeijingAirQuality archive, and the
    main Kaggle datasets in that order; always returns True.
    """
    root = os.getcwd()
    data_dir = os.path.join(root, "data")
    expected = load_structure_json()

    # Adjacency matrices ship as a single 7z archive.
    adj_missing, _ = detect_data_integrity(data_dir, expected, check_adj=True)
    if adj_missing:
        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)

    # BeijingAirQuality is hosted separately from the Kaggle datasets.
    if not os.path.isdir(os.path.join(data_dir, "BeijingAirQuality")):
        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)

    # Main datasets come from Kaggle and then need the layout normalized.
    _, main_missing = detect_data_integrity(data_dir, expected, check_adj=False)
    if main_missing:
        kaggle_datasets = (
            "elmahy/pems-dataset",
            "scchuy/pemsbay",
            "annnnguyen/metr-la-dataset",
            "wangshaoqi/solar-energy",
        )
        for dataset in kaggle_datasets:
            download_kaggle_data(root, dataset)
        rearrange_dir()
    return True
# Script entry point: verify local datasets and fetch anything missing.
if __name__=="__main__":
    check_and_download_data()