# TrafficWheel/utils/Download_data.py
import os, json, shutil, requests
from urllib.parse import urlsplit
from tqdm import tqdm
import kagglehub
import py7zr
# ---------- 1. Load the dataset structure JSON ----------
def load_structure_json(path="utils/dataset.json"):
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)
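
# A minimal sketch of the shape utils/dataset.json is assumed to have, inferred
# from how detect_data_integrity consumes it (folder name -> list of required
# files). The entries below are illustrative, not the real manifest:
# {
#     "PEMS04": ["PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
#     "METR-LA": ["metr-la.h5"]
# }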
# ---------- 2. Check data integrity ----------
def detect_data_integrity(data_dir, expected):
    missing_list = []
    if not os.path.isdir(data_dir):
        # If the data directory does not exist, every dataset is missing
        missing_list.extend(expected.keys())
        # Mark the adjacency (distance-matrix) files as missing as well
        missing_list.append("adj")
        return missing_list
    # Check the distance-matrix files used to build adjacency matrices
    has_missing_adj = False
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if os.path.isdir(folder_path):
            existing = set(os.listdir(folder_path))
            for f in files:
                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                    has_missing_adj = True
                    break
    if has_missing_adj:
        missing_list.append("adj")
    # Check the main dataset files
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            missing_list.append(folder)
            continue
        existing = set(os.listdir(folder_path))
        has_missing_file = False
        for f in files:
            # Skip distance-matrix files; they were already checked above
            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                has_missing_file = True
        if has_missing_file and folder not in missing_list:
            missing_list.append(folder)
    # print(f"Missing datasets: {missing_list}")
    return missing_list
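
# Example return value (illustrative): if data/PEMS04/ lacks one of its main
# files and any folder lacks a *_dtw_distance.npy / *_spatial_distance.npy
# file, the result would look like ["adj", "PEMS04"].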
# ---------- 3. Download a 7z archive and extract it ----------
def download_and_extract(url, dst_dir, max_retries=3):
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)
    for attempt in range(1, max_retries + 1):
        try:
            # Download with a streaming request and a progress bar
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with open(file_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        bar.update(len(chunk))
            # Extract the 7z archive, then remove it
            with py7zr.SevenZipFile(file_path, mode='r') as archive:
                archive.extractall(path=dst_dir)
            os.remove(file_path)
            return
        except Exception as e:
            if attempt == max_retries:
                raise RuntimeError("Download or extraction failed")
            print("Error, retrying...", e)
# ---------- 4. Download Kaggle data ----------
def download_kaggle_data(base_dir, dataset):
    try:
        print(f"Downloading kaggle dataset: {dataset}")
        path = kagglehub.dataset_download(dataset)
        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
    except Exception as e:
        print("Kaggle download failed:", dataset, e)
# ---------- 5. Download GitHub data ----------
def download_github_data(file_path, save_dir):
    os.makedirs(save_dir, exist_ok=True)
    # ghfast.top proxies raw.githubusercontent.com; switch to the commented
    # direct URL below if the proxy is unavailable
    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    response = requests.head(raw_url, allow_redirects=True)
    if response.status_code != 200:
        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
        return
    file_size = int(response.headers.get('Content-Length', 0))
    response = requests.get(raw_url, stream=True, allow_redirects=True)
    file_name = os.path.basename(file_path)
    file_path_to_save = os.path.join(save_dir, file_name)
    with open(file_path_to_save, 'wb') as f:
        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))
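
# Usage sketch (mirrors the TaxiBJ loop in check_and_download_data):
#     download_github_data("Datasets/TaxiBJ/TaxiBJ2013.npy", "data/BeijingTaxi")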
# ---------- 6. Tidy up the directory layout ----------
def rearrange_dir():
    data_dir = os.path.join(os.getcwd(), "data")
    nested = os.path.join(data_dir, "data")
    # Flatten an accidentally nested data/data/ directory
    if os.path.isdir(nested):
        for item in os.listdir(nested):
            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
            if os.path.isdir(src):
                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
            else:
                shutil.copy2(src, dst)
        shutil.rmtree(nested)
    # Move loose PEMS-BAY / METR-LA files into their own folders
    for kw, tgt in [("bay", "PEMS-BAY"), ("metr", "METR-LA")]:
        dst = os.path.join(data_dir, tgt)
        os.makedirs(dst, exist_ok=True)
        for f in os.listdir(data_dir):
            if kw in f.lower() and f.endswith((".h5", ".pkl")):
                shutil.move(os.path.join(data_dir, f), os.path.join(dst, f))
    # Rename the Kaggle solar-energy folder and file to SolarEnergy
    solar = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar):
        dst = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(dst, exist_ok=True)
        csv = os.path.join(solar, "solar_AL.csv")
        if os.path.isfile(csv):
            shutil.move(csv, os.path.join(dst, "SolarEnergy.csv"))
        shutil.rmtree(solar)
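
# Resulting layout (illustrative): a loose file such as data/metr-la.h5 ends
# up at data/METR-LA/metr-la.h5, and data/solar-energy/solar_AL.csv becomes
# data/SolarEnergy/SolarEnergy.csv.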
# ---------- 7. Main workflow ----------
def check_and_download_data():
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, "data")
    file_tree = load_structure_json()
    # Run the integrity check once to collect every missing item
    missing_list = detect_data_integrity(data_dir, file_tree)
    print(f"Missing datasets: {missing_list}")
    # Check for and download the adjacency data
    if "adj" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
        # Remove adj from the missing list after downloading
        missing_list.remove("adj")
    # Check BeijingAirQuality and AirQuality
    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)
    # Check for and download the TaxiBJ data
    if "TaxiBJ" in missing_list:
        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
        for file in taxibj_files:
            file_path = f"Datasets/TaxiBJ/{file}"
            download_github_data(file_path, taxi_bj_folder)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)
    # Check for and download the PEMS, PEMS-BAY, METR-LA and solar-energy data.
    # Map dataset names to their Kaggle dataset identifiers.
    kaggle_map = {
        "PEMS03": "elmahy/pems-dataset",
        "PEMS04": "elmahy/pems-dataset",
        "PEMS07": "elmahy/pems-dataset",
        "PEMS08": "elmahy/pems-dataset",
        "PEMS-BAY": "scchuy/pemsbay",
        "METR-LA": "annnnguyen/metr-la-dataset",
        "SolarEnergy": "wangshaoqi/solar-energy"
    }
    # Deduplicate Kaggle downloads so that one Kaggle dataset covering
    # several folders is not fetched repeatedly
    downloaded_kaggle_datasets = set()
    for dataset, kaggle_ds in kaggle_map.items():
        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
            download_kaggle_data(cwd, kaggle_ds)
            # Record the dataset as downloaded
            downloaded_kaggle_datasets.add(kaggle_ds)
            # Refresh the missing list after each download
            missing_list = detect_data_integrity(data_dir, file_tree)
    rearrange_dir()
    return True
if __name__ == "__main__":
    check_and_download_data()