import os
import json
import shutil

import requests
from urllib.parse import urlsplit
from tqdm import tqdm
import kagglehub
import py7zr


# ---------- 1. Check data integrity ----------
def detect_data_integrity(data_dir, expected):
    missing_list = []
    if not os.path.isdir(data_dir):
        # If the data directory does not exist, every dataset is missing
        missing_list.extend(expected.keys())
        # Mark adj (the distance matrices) as missing as well
        missing_list.append("adj")
        return missing_list

    # Check the adj-related files (distance-matrix .npy files)
    has_missing_adj = False
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if os.path.isdir(folder_path):
            existing = set(os.listdir(folder_path))
            for f in files:
                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                    has_missing_adj = True
                    break
    if has_missing_adj:
        missing_list.append("adj")

    # Check the main files of each dataset
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        if not os.path.isdir(folder_path):
            missing_list.append(folder)
            continue
        existing = set(os.listdir(folder_path))
        has_missing_file = False
        for f in files:
            # Skip distance-matrix files; they were already checked above
            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
                has_missing_file = True
        if has_missing_file and folder not in missing_list:
            missing_list.append(folder)
    # print(f"Missing datasets: {missing_list}")
    return missing_list


# ---------- 2. Download a 7z archive and extract it ----------
def download_and_extract(url, dst_dir, max_retries=3):
    os.makedirs(dst_dir, exist_ok=True)
    filename = os.path.basename(urlsplit(url).path) or "download.7z"
    file_path = os.path.join(dst_dir, filename)
    for attempt in range(1, max_retries + 1):
        try:
            # Download with a progress bar
            with requests.get(url, stream=True, timeout=30) as r:
                r.raise_for_status()
                total = int(r.headers.get("content-length", 0))
                with open(file_path, "wb") as f, tqdm(total=total, unit="B",
                                                      unit_scale=True, desc=filename) as bar:
                    for chunk in r.iter_content(8192):
                        f.write(chunk)
                        bar.update(len(chunk))
            # Extract the 7z archive, then remove it
            with py7zr.SevenZipFile(file_path, mode="r") as archive:
                archive.extractall(path=dst_dir)
            os.remove(file_path)
            return
        except Exception as e:
            if attempt == max_retries:
                raise RuntimeError("Download or extraction failed")
            print("Error, retrying...", e)


# ---------- 3. Download Kaggle data ----------
def download_kaggle_data(base_dir, dataset):
    try:
        print(f"Downloading kaggle dataset: {dataset}")
        path = kagglehub.dataset_download(dataset)
        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
    except Exception as e:
        print("Kaggle download failed:", dataset, e)


# ---------- 4. Download GitHub data ----------
def download_github_data(file_path, save_dir):
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
    response = requests.head(raw_url, allow_redirects=True)
    if response.status_code != 200:
        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
        return
    file_size = int(response.headers.get("Content-Length", 0))
    response = requests.get(raw_url, stream=True, allow_redirects=True)
    file_name = os.path.basename(file_path)
    file_path_to_save = os.path.join(save_dir, file_name)
    with open(file_path_to_save, "wb") as f:
        with tqdm(total=file_size, unit="B", unit_scale=True,
                  desc=f"Downloading {file_name}") as pbar:
            for chunk in response.iter_content(chunk_size=1024):
                if chunk:
                    f.write(chunk)
                    pbar.update(len(chunk))
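
# ---------- Example: structure consumed by detect_data_integrity ----------
# A minimal sketch of the mapping loaded from utils/dataset.json. The folder
# and file names below are illustrative assumptions, not the real contents of
# that file: each key is a dataset folder expected under data/, each value the
# list of files that must exist inside it. Entries ending in _dtw_distance.npy
# or _spatial_distance.npy are treated as adjacency ("adj") files.
_EXAMPLE_FILE_TREE = {
    "PEMS04": ["pems04.npz", "pems04_dtw_distance.npy", "pems04_spatial_distance.npy"],
    "METR-LA": ["metr-la.h5"],
}
# detect_data_integrity("data", _EXAMPLE_FILE_TREE) would return the keys whose
# folders or files are absent, plus the sentinel "adj" when any distance
# matrix is missing.
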
# ---------- 5. Rearrange directories ----------
def rearrange_dir():
    data_dir = os.path.join(os.getcwd(), "data")
    nested = os.path.join(data_dir, "data")
    # Flatten an accidental data/data/ nesting left behind by extraction
    if os.path.isdir(nested):
        for item in os.listdir(nested):
            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
            if os.path.isdir(src):
                shutil.copytree(src, dst, dirs_exist_ok=True)  # update existing directories
            else:
                shutil.copy2(src, dst)
        shutil.rmtree(nested)
    # Move loose PEMS-BAY / METR-LA files into their own folders
    for kw, tgt in [("bay", "PEMS-BAY"), ("metr", "METR-LA")]:
        dst = os.path.join(data_dir, tgt)
        os.makedirs(dst, exist_ok=True)
        for f in os.listdir(data_dir):
            if kw in f.lower() and f.endswith((".h5", ".pkl")):
                shutil.move(os.path.join(data_dir, f), os.path.join(dst, f))
    # Normalize the solar-energy folder to SolarEnergy/SolarEnergy.csv
    solar = os.path.join(data_dir, "solar-energy")
    if os.path.isdir(solar):
        dst = os.path.join(data_dir, "SolarEnergy")
        os.makedirs(dst, exist_ok=True)
        csv = os.path.join(solar, "solar_AL.csv")
        if os.path.isfile(csv):
            shutil.move(csv, os.path.join(dst, "SolarEnergy.csv"))
        shutil.rmtree(solar)


# ---------- 6. Main flow ----------
def check_and_download_data():
    # Load the structure file and detect missing datasets
    cwd = os.getcwd()
    data_dir = os.path.join(cwd, "data")
    with open("utils/dataset.json", "r", encoding="utf-8") as f:
        file_tree = json.load(f)
    missing_list = detect_data_integrity(data_dir, file_tree)
    # print(f"Missing datasets: {missing_list}")

    # Check and download the adj data (distance matrices)
    if "adj" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
        # Remove adj from the missing list after downloading
        missing_list.remove("adj")

    # Check BeijingAirQuality and AirQuality
    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)

    # Check and download the TaxiBJ data
    if "BeijingTaxi" in missing_list:
        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy',
                        'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
        for file in taxibj_files:
            file_path = f"Datasets/TaxiBJ/{file}"
            download_github_data(file_path, taxi_bj_folder)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)

    # Check and download the NYCBike data
    if "NYCBike" in missing_list:
        download_and_extract("http://code.zhang-heng.com/static/NYCBike.7z", data_dir)
        # Refresh the missing list after downloading
        missing_list = detect_data_integrity(data_dir, file_tree)

    # Check and download the PEMS, PEMS-BAY, METR-LA, and solar-energy data
    kaggle_map = {
        "PEMS03": "elmahy/pems-dataset",
        "PEMS04": "elmahy/pems-dataset",
        "PEMS07": "elmahy/pems-dataset",
        "PEMS08": "elmahy/pems-dataset",
        "PEMS-BAY": "scchuy/pemsbay",
        "METR-LA": "annnnguyen/metr-la-dataset",
        "SolarEnergy": "wangshaoqi/solar-energy"
    }
    # Deduplicate the Kaggle sources so the same dataset is not downloaded twice
    downloaded_kaggle_datasets = set()
    for dataset, kaggle_ds in kaggle_map.items():
        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
            download_kaggle_data(cwd, kaggle_ds)
            # Record the source as downloaded
            downloaded_kaggle_datasets.add(kaggle_ds)
            # Refresh the missing list after each download
            missing_list = detect_data_integrity(data_dir, file_tree)

    rearrange_dir()
    return True


if __name__ == "__main__":
    check_and_download_data()
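
# ---------- Usage sketch ----------
# A minimal example of driving this module from another script and then
# verifying the result; the module name "check_data" is an assumed placeholder
# for wherever this file lives in the repo, not a name from the source.
#
#     import json
#     from check_data import check_and_download_data, detect_data_integrity
#
#     check_and_download_data()
#     with open("utils/dataset.json", encoding="utf-8") as f:
#         tree = json.load(f)
#     assert detect_data_integrity("data", tree) == [], "datasets still missing"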