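"""Check local dataset integrity and download any missing datasets.

Compares the contents of ./data against the expected file tree in
utils/dataset.json, then fetches missing pieces from a static mirror
(7z archives), Kaggle (via kagglehub), and GitHub.
"""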
import os, json, shutil, requests
from urllib.parse import urlsplit
from tqdm import tqdm
import kagglehub
import py7zr

# ---------- 1. Load the structure JSON ----------
def load_structure_json(path="utils/dataset.json"):
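    """Load the expected dataset file tree from a JSON file.

    The rest of the script expects a mapping of folder name to file list,
    e.g. (hypothetical example): {"PEMS04": ["PEMS04.npz", "PEMS04_dtw_distance.npy"]}.
    """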
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

# ---------- 2. Check data integrity ----------
def detect_data_integrity(data_dir, expected):
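    """Compare data_dir against the expected file tree and return a list of
    missing dataset folders; "adj" is appended when any distance-matrix
    file is absent."""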
    missing_list = []
    if not os.path.isdir(data_dir):
        # If the data directory does not exist, every dataset is missing
        missing_list.extend(expected.keys())
        # Mark "adj" as missing as well
        missing_list.append("adj")
        return missing_list

    # Check the adjacency files (distance matrices)
    has_missing_adj = False
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if os.path.isdir(folder_path):
            existing = set(os.listdir(folder_path))
            for f in files:
                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                    has_missing_adj = True
                    break
    if has_missing_adj:
        missing_list.append("adj")

    # Check the main dataset files
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            missing_list.append(folder)
            continue

        existing = set(os.listdir(folder_path))
        has_missing_file = False

        for f in files:
            # Skip distance-matrix files; they were already checked above
            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                has_missing_file = True

        if has_missing_file and folder not in missing_list:
            missing_list.append(folder)

    # print(f"Missing datasets: {missing_list}")
    return missing_list

# ---------- 3. Download a 7z archive and extract it ----------
def download_and_extract(url, dst_dir, max_retries=3):
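    """Stream a .7z archive from url into dst_dir, extract it, and delete the
    archive, retrying up to max_retries times before raising RuntimeError."""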
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)
    for attempt in range(1, max_retries + 1):
        try:
            # Download
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with open(file_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        bar.update(len(chunk))
            # Extract the 7z archive
            with py7zr.SevenZipFile(file_path, mode='r') as archive:
                archive.extractall(path=dst_dir)
            os.remove(file_path)
            return
        except Exception as e:
            if attempt == max_retries:
                raise RuntimeError("Download or extraction failed") from e
            print("Error, retrying...", e)

# ---------- 4. Download Kaggle data ----------
def download_kaggle_data(base_dir, dataset):
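    """Download a Kaggle dataset via kagglehub and copy it into base_dir/data."""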
    try:
        print(f"Downloading Kaggle dataset: {dataset}")
        path = kagglehub.dataset_download(dataset)
        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
    except Exception as e:
        print("Kaggle download failed:", dataset, e)

# ---------- 5. Download GitHub data ----------
def download_github_data(file_path, save_dir):
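    """Download file_path from the Traffic-Flow-Prediction GitHub repository
    (through the ghfast.top mirror) into save_dir with a progress bar."""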
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    # Direct URL without the mirror:
    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    response = requests.head(raw_url, allow_redirects=True)
    if response.status_code != 200:
        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
        return

    file_size = int(response.headers.get('Content-Length', 0))
    response = requests.get(raw_url, stream=True, allow_redirects=True)
    file_name = os.path.basename(file_path)
    file_path_to_save = os.path.join(save_dir, file_name)
    with open(file_path_to_save, 'wb') as f:
        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))

# ---------- 6. Rearrange the data directory ----------
def rearrange_dir():
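    """Flatten a nested data/data directory, group PEMS-BAY / METR-LA files
    into their own folders, and rename the solar-energy dataset."""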
    data_dir = os.path.join(os.getcwd(), "data")
    nested = os.path.join(data_dir, "data")
    if os.path.isdir(nested):
        for item in os.listdir(nested):
            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
            if os.path.isdir(src):
                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
            else:
                shutil.copy2(src, dst)
        shutil.rmtree(nested)

    for kw, tgt in [("bay", "PEMS-BAY"), ("metr", "METR-LA")]:
        dst = os.path.join(data_dir, tgt)
        os.makedirs(dst, exist_ok=True)
        for f in os.listdir(data_dir):
            if kw in f.lower() and f.endswith((".h5", ".pkl")):
                shutil.move(os.path.join(data_dir, f), os.path.join(dst, f))

    solar = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar):
        dst = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(dst, exist_ok=True)
        csv = os.path.join(solar, "solar_AL.csv")
        if os.path.isfile(csv):
            shutil.move(csv, os.path.join(dst, "SolarEnergy.csv"))
        shutil.rmtree(solar)

# ---------- 7. Main workflow ----------
def check_and_download_data():
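    """Detect missing datasets, download each from its source, then tidy up
    the data directory. Returns True when the workflow completes."""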
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, "data")
    file_tree = load_structure_json()

    # Run the check once to collect all missing items
    missing_list = detect_data_integrity(data_dir, file_tree)
    print(f"Missing datasets: {missing_list}")

    # Check for and download the adjacency data
    if "adj" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
        # Remove "adj" from the missing list after downloading
        missing_list.remove("adj")

    # Check BeijingAirQuality and AirQuality
    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)

    # Check for and download the TaxiBJ data
    if "TaxiBJ" in missing_list:
        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
        for file in taxibj_files:
            file_path = f"Datasets/TaxiBJ/{file}"
            download_github_data(file_path, taxi_bj_folder)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)

    # Check for and download the PEMS, PEMS-BAY, METR-LA, and solar-energy data.
    # Map each dataset name to its Kaggle dataset.
    kaggle_map = {
        "PEMS03": "elmahy/pems-dataset",
        "PEMS04": "elmahy/pems-dataset",
        "PEMS07": "elmahy/pems-dataset",
        "PEMS08": "elmahy/pems-dataset",
        "PEMS-BAY": "scchuy/pemsbay",
        "METR-LA": "annnnguyen/metr-la-dataset",
        "SolarEnergy": "wangshaoqi/solar-energy"
    }

    # Deduplicate the Kaggle sources so the same dataset is not downloaded twice
    downloaded_kaggle_datasets = set()

    for dataset, kaggle_ds in kaggle_map.items():
        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
            download_kaggle_data(cwd, kaggle_ds)
            # Record the source as downloaded
            downloaded_kaggle_datasets.add(kaggle_ds)
            # Refresh the missing list after each download
            missing_list = detect_data_integrity(data_dir, file_tree)

    rearrange_dir()
    return True

if __name__ == "__main__":
    check_and_download_data()