import json
import os
import shutil
import time
from urllib.parse import urlsplit

import kagglehub
import py7zr
import requests
from tqdm import tqdm


# ---------- 1. Load the expected-structure JSON ----------
def load_structure_json(path="utils/dataset.json"):
    """Load the JSON file that describes the expected dataset layout.

    Args:
        path: path to the JSON file; a relative path is resolved against the
            current working directory.

    Returns:
        The decoded JSON object. ``detect_data_integrity`` consumes it as a
        mapping of folder name -> iterable of file names.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


# ---------- 2. Check data integrity ----------
# File-name suffixes that mark adjacency (distance-matrix) files.
_ADJ_SUFFIXES = ("_dtw_distance.npy", "_spatial_distance.npy")


def detect_data_integrity(data_dir, expected, check_adj=False):
    """Report which parts of the expected dataset layout are missing.

    Args:
        data_dir: root directory holding one sub-folder per dataset.
        expected: mapping of sub-folder name -> iterable of file names that
            should exist in it (presumably the content of utils/dataset.json).
        check_adj: when True, a missing sub-folder is counted as missing
            adjacency data and missing non-adjacency files are ignored; when
            False, a missing sub-folder or non-adjacency file is counted as
            missing main data.

    Returns:
        ``(missing_adj, missing_main)`` booleans. Both are True when
        ``data_dir`` itself does not exist. A missing file ending in one of
        ``_ADJ_SUFFIXES`` always sets ``missing_adj``.
    """
    if not os.path.isdir(data_dir):
        return True, True

    missing_adj = missing_main = False
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            # The adjacency archive creates the dataset folders, so in
            # adjacency mode a missing folder means "fetch adjacency data".
            if check_adj:
                missing_adj = True
            else:
                missing_main = True
            continue

        existing = set(os.listdir(folder_path))
        for name in files:
            if name in existing:
                continue
            if name.endswith(_ADJ_SUFFIXES):
                missing_adj = True
            elif not check_adj:
                missing_main = True
    return missing_adj, missing_main


# ---------- 3. Download a .7z archive and extract it ----------
def _remove_quietly(path):
    """Delete ``path`` if present; best-effort cleanup that ignores OS errors."""
    try:
        os.remove(path)
    except OSError:
        pass


def _download_file(url, file_path, desc=None, chunk_size=8192, timeout=30):
    """Stream ``url`` into ``file_path`` with a tqdm progress bar.

    Raises:
        requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
    """
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        # A missing content-length gives 0; pass None so tqdm shows an
        # open-ended bar instead of a bogus 0-byte total.
        total = int(r.headers.get("content-length", 0)) or None
        with open(file_path, "wb") as f, tqdm(
            total=total, unit="B", unit_scale=True, desc=desc
        ) as bar:
            for chunk in r.iter_content(chunk_size):
                f.write(chunk)
                bar.update(len(chunk))


def download_and_extract(url, dst_dir, max_retries=3, retry_delay=2.0):
    """Download a .7z archive from ``url`` and extract it into ``dst_dir``.

    The archive is saved in ``dst_dir`` under the URL's basename (or
    ``download.7z``), extracted there, and then deleted. Download plus
    extraction is retried as a unit.

    Args:
        url: archive URL.
        dst_dir: destination directory (created if needed).
        max_retries: total number of attempts; values below 1 still make one
            attempt.
        retry_delay: base back-off in seconds; attempt ``k`` waits
            ``retry_delay * k`` before the next try.

    Raises:
        RuntimeError: when every attempt fails. The last underlying error is
            chained as ``__cause__``.
    """
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)
    attempts = max(1, max_retries)

    for attempt in range(1, attempts + 1):
        try:
            _download_file(url, file_path, desc=filename)
            with py7zr.SevenZipFile(file_path, mode="r") as archive:
                archive.extractall(path=dst_dir)
            os.remove(file_path)
            return
        except Exception as e:  # retry boundary: network, HTTP, I/O or archive errors
            # Never leave a partial / corrupt archive lying around.
            _remove_quietly(file_path)
            if attempt == attempts:
                raise RuntimeError("下载或解压失败") from e
            print("错误,重试中...", e)
            time.sleep(retry_delay * attempt)


# ---------- 4. Download Kaggle datasets ----------
def download_kaggle_data(base_dir, dataset):
    """Download a Kaggle dataset and merge its files into ``<base_dir>/data``.

    This is best-effort: any failure is printed and swallowed so the remaining
    datasets can still be fetched. Callers should re-check integrity afterwards.

    Args:
        base_dir: project root; files are copied into its ``data`` sub-folder,
            merging with whatever already exists there.
        dataset: Kaggle dataset handle, e.g. ``"owner/name"``.
    """
    try:
        path = kagglehub.dataset_download(dataset)
        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
    except Exception as e:  # deliberate best-effort boundary
        print("Kaggle 下载失败:", dataset, e)
# ---------- 5. Rearrange the downloaded files ----------
def rearrange_dir(base_dir=None):
    """Normalise the layout of ``<base_dir>/data`` after the Kaggle downloads.

    Steps:
      1. Merge a nested ``data/data`` directory up one level, overwriting
         same-named files, then remove it.
      2. Move loose ``*.h5`` / ``*.pkl`` files whose names contain ``bay`` /
         ``metr`` (case-insensitive) into ``PEMS-BAY/`` / ``METR-LA/``. Both
         folders are always created.
      3. If ``solar-energy/`` exists, move its ``solar_AL.csv`` to
         ``SolarEnergy/SolarEnergy.csv`` and delete ``solar-energy/``.

    Args:
        base_dir: project root containing ``data``; defaults to the current
            working directory.
    """
    if base_dir is None:
        base_dir = os.getcwd()
    data_dir = os.path.join(base_dir, "data")

    nested = os.path.join(data_dir, "data")
    if os.path.isdir(nested):
        for item in os.listdir(nested):
            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
            if os.path.isdir(src):
                shutil.copytree(src, dst, dirs_exist_ok=True)  # merge into an existing dir
            else:
                shutil.copy2(src, dst)
        shutil.rmtree(nested)

    for keyword, target in (("bay", "PEMS-BAY"), ("metr", "METR-LA")):
        target_dir = os.path.join(data_dir, target)
        os.makedirs(target_dir, exist_ok=True)
        for name in os.listdir(data_dir):
            lowered = name.lower()
            # Match extension case-insensitively, like the keyword itself.
            if keyword in lowered and lowered.endswith((".h5", ".pkl")):
                shutil.move(os.path.join(data_dir, name), os.path.join(target_dir, name))

    solar = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar):
        target_dir = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(target_dir, exist_ok=True)
        csv_path = os.path.join(solar, "solar_AL.csv")
        if os.path.isfile(csv_path):
            shutil.move(csv_path, os.path.join(target_dir, "SolarEnergy.csv"))
        shutil.rmtree(solar)


# ---------- 6. Main entry point ----------
# NOTE(review): these archives are fetched over plain HTTP with no checksum
# verification; switch to HTTPS and/or verify hashes if the host supports it.
_ADJ_URL = "http://code.zhang-heng.com/static/adj.7z"
_BAQ_URL = "http://code.zhang-heng.com/static/BeijingAirQuality.7z"
_KAGGLE_DATASETS = (
    "elmahy/pems-dataset",
    "scchuy/pemsbay",
    "annnnguyen/metr-la-dataset",
    "wangshaoqi/solar-energy",
)


def check_and_download_data():
    """Ensure ``./data`` matches utils/dataset.json, downloading what is missing.

    1. If adjacency data is missing, download and extract the adjacency archive.
    2. If ``data/BeijingAirQuality`` is absent, download and extract it.
    3. If main data files are still missing, fetch the Kaggle datasets,
       rearrange the result, and print a warning if files remain missing.

    Returns:
        True, unconditionally (kept for backward compatibility).

    Raises:
        RuntimeError: propagated from ``download_and_extract`` when an archive
            cannot be fetched.
    """
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, "data")
    expected = load_structure_json()

    missing_adj, _ = detect_data_integrity(data_dir, expected, check_adj=True)
    if missing_adj:
        download_and_extract(_ADJ_URL, data_dir)

    if not os.path.isdir(os.path.join(data_dir, "BeijingAirQuality")):
        download_and_extract(_BAQ_URL, data_dir)

    _, missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
    if missing_main:
        for dataset in _KAGGLE_DATASETS:
            download_kaggle_data(cwd, dataset)
        rearrange_dir(cwd)
        # Kaggle download failures are only printed, not raised, so verify
        # the final layout instead of assuming success.
        _, still_missing = detect_data_integrity(data_dir, expected, check_adj=False)
        if still_missing:
            print("警告: 下载后仍有数据文件缺失, 请检查上方的错误信息")
    return True


if __name__ == "__main__":
    check_and_download_data()