115 lines
4.5 KiB
Python
Executable File
115 lines
4.5 KiB
Python
Executable File
import os, json, shutil, requests
|
|
from urllib.parse import urlsplit
|
|
from tqdm import tqdm
|
|
import kagglehub
|
|
import py7zr
|
|
|
|
# ---------- 1. 加载结构 JSON ----------
|
|
def load_structure_json(path="utils/dataset.json"):
    """Read the expected dataset layout from a UTF-8 encoded JSON file.

    Args:
        path: Location of the JSON file. A relative path is resolved against
            the current working directory.

    Returns:
        The decoded JSON document. Callers in this module iterate it as a
        mapping of folder name -> list of expected file names.
    """
    with open(path, encoding="utf-8") as handle:
        layout = json.load(handle)
    return layout
|
|
|
|
# ---------- 2. 检测完整性 ----------
|
|
def detect_data_integrity(data_dir, expected, check_adj=False):
    """Report which parts of the dataset tree under *data_dir* are missing.

    Each file listed in *expected* is classified as adjacency data when its
    name ends with ``_dtw_distance.npy`` or ``_spatial_distance.npy``, and as
    main data otherwise. A sub-folder that does not exist counts as missing
    every file listed for it. It therefore only raises the flag(s) matching
    the kinds of files it was supposed to contain, rather than being blamed
    wholesale on one category.

    Args:
        data_dir: Root directory of the dataset tree.
        expected: Mapping of sub-folder name -> iterable of file names that
            must exist directly inside that sub-folder.
        check_adj: When True, missing main-data files are ignored, so
            ``missing_main`` stays False (unless *data_dir* itself is absent).

    Returns:
        ``(missing_adj, missing_main)``. Both are True when *data_dir* does
        not exist.
    """
    if not os.path.isdir(data_dir):
        return True, True

    adj_suffixes = ("_dtw_distance.npy", "_spatial_distance.npy")
    missing_adj, missing_main = False, False
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        # A missing folder simply means none of its files exist; classify them
        # one by one like any other missing file.
        if os.path.isdir(folder_path):
            existing = set(os.listdir(folder_path))
        else:
            existing = set()
        for name in files:
            if name in existing:
                continue
            if name.endswith(adj_suffixes):
                missing_adj = True
            elif not check_adj:
                missing_main = True
    return missing_adj, missing_main
|
|
|
|
# ---------- 3. 下载 7z 并解压 ----------
|
|
def download_and_extract(url, dst_dir, max_retries=3):
    """Download a ``.7z`` archive from *url* and extract it into *dst_dir*.

    The archive is streamed to ``dst_dir/<basename of the URL path>`` (or
    ``dst_dir/download.7z`` when the URL path has no basename), extracted in
    place with py7zr, and then deleted. Any error during download or
    extraction triggers another attempt, up to *max_retries* attempts total.
    The archive file is removed after every attempt, successful or not, so a
    partial download never lingers in *dst_dir*.

    Args:
        url: URL of the archive.
        dst_dir: Directory to extract into; created if it does not exist.
        max_retries: Total number of attempts before giving up.

    Raises:
        RuntimeError: if every attempt fails. The last underlying error is
            attached as ``__cause__``.
    """
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)

    for attempt in range(1, max_retries + 1):
        try:
            # Download: stream to disk in chunks instead of buffering in memory.
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                # content-length may be absent; None tells tqdm the size is unknown.
                total = int(r.headers.get("content-length", 0)) or None
                with open(file_path, "wb") as f, tqdm(
                    total=total, unit="B", unit_scale=True, desc=filename
                ) as bar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        bar.update(len(chunk))

            # Extract the 7z archive into the destination directory.
            # NOTE(review): archives are fetched over plain HTTP with no
            # checksum verification; consider pinning a hash before extracting.
            with py7zr.SevenZipFile(file_path, mode='r') as archive:
                archive.extractall(path=dst_dir)
            return
        except Exception as e:
            if attempt == max_retries:
                raise RuntimeError("下载或解压失败") from e
            print("错误,重试中...", e)
        finally:
            # Best-effort cleanup of the (complete or partial) archive. A stray
            # file is harmless, whereas raising here would mask the real outcome.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except OSError:
                    pass
|
|
|
|
# ---------- 4. 下载 Kaggle 数据 ----------
|
|
def download_kaggle_data(base_dir, dataset):
    """Fetch a Kaggle dataset and merge its files into ``<base_dir>/data``.

    This is best-effort. Any error is printed and swallowed so the caller can
    continue with its remaining datasets.

    Args:
        base_dir: Directory whose ``data`` sub-folder receives the files
            (existing files and folders are merged/overwritten).
        dataset: Kaggle dataset handle, e.g. ``"owner/name"``.
    """
    try:
        local_copy = kagglehub.dataset_download(dataset)
        destination = os.path.join(base_dir, "data")
        shutil.copytree(local_copy, destination, dirs_exist_ok=True)
    except Exception as err:  # deliberate best-effort: report and move on
        print("Kaggle 下载失败:", dataset, err)
|
|
|
|
# ---------- 5. 整理目录 ----------
|
|
def _move_overwriting(src, dst):
    """Move *src* to *dst*, first deleting an existing regular file at *dst*.

    ``shutil.move`` relies on ``os.rename``, which refuses to replace an
    existing file on Windows; clearing the target keeps re-runs working.
    """
    if os.path.isfile(dst):
        os.remove(dst)
    shutil.move(src, dst)


def rearrange_dir(data_dir=None):
    """Normalise the layout of the downloaded dataset tree.

    Steps:
      1. Merge a nested ``<data_dir>/data`` directory up one level (existing
         entries are updated), then delete it. Such a folder presumably
         appears when a downloaded dataset ships its own ``data`` folder.
      2. Move top-level files whose names contain ``bay`` / ``metr``
         (case-insensitive) and end with ``.h5`` or ``.pkl`` into
         ``PEMS-BAY`` / ``METR-LA`` respectively (both folders are created).
      3. Move ``solar-energy/solar_AL.csv`` to ``SolarEnergy/SolarEnergy.csv``
         and delete the ``solar-energy`` folder.

    Files already present at a destination are replaced.

    Args:
        data_dir: Root of the dataset tree. Defaults to ``<cwd>/data``.
    """
    if data_dir is None:
        data_dir = os.path.join(os.getcwd(), "data")

    nested = os.path.join(data_dir, "data")
    if os.path.isdir(nested):
        for item in os.listdir(nested):
            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
            if os.path.isdir(src):
                # Merge into (and update) an existing directory of the same name.
                shutil.copytree(src, dst, dirs_exist_ok=True)
            else:
                shutil.copy2(src, dst)
        shutil.rmtree(nested)

    for keyword, target in (("bay", "PEMS-BAY"), ("metr", "METR-LA")):
        target_dir = os.path.join(data_dir, target)
        os.makedirs(target_dir, exist_ok=True)
        # os.listdir returns a snapshot list, so moving entries out is safe.
        for name in os.listdir(data_dir):
            if keyword in name.lower() and name.endswith((".h5", ".pkl")):
                _move_overwriting(os.path.join(data_dir, name),
                                  os.path.join(target_dir, name))

    solar = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar):
        solar_dst = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(solar_dst, exist_ok=True)
        csv_path = os.path.join(solar, "solar_AL.csv")
        if os.path.isfile(csv_path):
            _move_overwriting(csv_path, os.path.join(solar_dst, "SolarEnergy.csv"))
        shutil.rmtree(solar)
|
|
|
|
# ---------- 6. 主流程 ----------
|
|
def check_and_download_data():
    """Make sure the dataset tree under ``<cwd>/data`` is present and complete.

    1. If any adjacency file listed in the layout JSON (``utils/dataset.json``
       by default) is missing, download and extract ``adj.7z``.
    2. If the ``BeijingAirQuality`` folder is absent, download and extract it.
    3. If any main-data file is missing, fetch the Kaggle datasets.
    4. Re-arrange the tree into the expected folder layout.

    Kaggle downloads are best-effort (``download_kaggle_data`` only prints
    failures). After fetching them, the layout is therefore re-checked, and a
    warning is printed if main data is still missing.

    Returns:
        True. The value is kept for backward compatibility; it does not by
        itself guarantee that every expected file is present.

    Raises:
        RuntimeError: if an archive download/extraction exhausts its retries.
    """
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, "data")
    expected = load_structure_json()

    missing_adj, _ = detect_data_integrity(data_dir, expected, check_adj=True)
    if missing_adj:
        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)

    baq_folder = os.path.join(data_dir, "BeijingAirQuality")
    if not os.path.isdir(baq_folder):
        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)

    _, missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
    if missing_main:
        for ds in ["elmahy/pems-dataset", "scchuy/pemsbay",
                   "annnnguyen/metr-la-dataset", "wangshaoqi/solar-energy"]:
            download_kaggle_data(cwd, ds)

    rearrange_dir()

    if missing_main:
        # Kaggle failures are swallowed above; verify the end result instead of
        # silently reporting success.
        _, still_missing = detect_data_integrity(data_dir, expected, check_adj=False)
        if still_missing:
            print("警告: 下载后数据仍不完整, 请检查上方的错误信息")
    return True
|
|
|
|
def _main():
    """Command-line entry point: fetch or repair the dataset tree under ./data."""
    check_and_download_data()


if __name__ == "__main__":
    _main()
|