Update dataset detection and download logic

czzhangheng 2025-11-23 17:58:44 +08:00
parent d06cf4e0aa
commit b7ea73bc92
3 changed files with 115 additions and 25 deletions

.vscode/launch.json (vendored): 7 additions

@@ -4,6 +4,13 @@
     // Visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        },
         {
             "name": "STID_PEMS-BAY",
             "type": "debugpy",


@@ -10,25 +10,48 @@ def load_structure_json(path="utils/dataset.json"):
         return json.load(f)
 
 # ---------- 2. Check dataset integrity ----------
-def detect_data_integrity(data_dir, expected, check_adj=False):
-    missing_adj, missing_main = False, False
-    if not os.path.isdir(data_dir): return True, True
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # If the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # Mark adj as missing as well
+        missing_list.append("adj")
+        return missing_list
+
+    # Check the adj files (distance matrices)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # Check the main dataset files
     for folder, files in expected.items():
         folder_path = os.path.join(data_dir, folder)
         if not os.path.isdir(folder_path):
-            if check_adj:
-                missing_adj = True
-                continue
-            missing_main = True
+            missing_list.append(folder)
             continue
         existing = set(os.listdir(folder_path))
+        has_missing_file = False
         for f in files:
-            if f not in existing:
-                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
-                    missing_adj = True
-                elif not check_adj:
-                    missing_main = True
-    return missing_adj, missing_main
+            # Skip the distance-matrix files; they were already checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list
 
 # ---------- 3. Download and extract 7z archives ----------
 def download_and_extract(url, dst_dir, max_retries=3):
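Taken together, the rewrite replaces the old two-flag return (missing_adj, missing_main) with a single flat list: one entry per dataset folder with missing main files, plus the sentinel string "adj" when any distance matrix is absent. A minimal sketch of consuming that return value, assuming it runs in the same module as the function above and that dataset.json sits at the default path used by load_structure_json:

import json
import os

# Load the expected file tree, mirroring load_structure_json's default path.
with open("utils/dataset.json", encoding="utf-8") as f:
    expected = json.load(f)

missing = detect_data_integrity(os.path.join(os.getcwd(), "data"), expected)
if "adj" in missing:
    print("distance matrices (adj) need to be fetched")
print("datasets to fetch:", [m for m in missing if m != "adj"])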
@@ -57,11 +80,34 @@ def download_and_extract(url, dst_dir, max_retries=3):
 # ---------- 4. Download Kaggle data ----------
 def download_kaggle_data(base_dir, dataset):
     try:
+        print(f"Downloading kaggle dataset : {dataset}")
         path = kagglehub.dataset_download(dataset)
         shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
     except Exception as e:
         print("Kaggle download failed:", dataset, e)
 
+# ---------- 5. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
 # ---------- 5. Rearrange the data directory ----------
 def rearrange_dir():
     data_dir = os.path.join(os.getcwd(), "data")
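For reference, a hedged usage sketch of the new helper: the file path mirrors the f"Datasets/TaxiBJ/{file}" pattern used later in check_and_download_data, and the call assumes the module-level imports (os, requests, tqdm) that the function body relies on:

import os

# Fetch one TaxiBJ shard into data/BeijingTaxi (paths taken from this commit).
download_github_data("Datasets/TaxiBJ/TaxiBJ2013.npy", os.path.join("data", "BeijingTaxi"))

Note that neither requests call passes a timeout, so a stalled connection will hang the download; passing timeout= to the head() and get() calls would be a straightforward hardening step.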
@@ -92,21 +138,57 @@ def rearrange_dir():
 def check_and_download_data():
     cwd = os.getcwd()
     data_dir = os.path.join(cwd,"data")
-    expected = load_structure_json()
-    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
-    if missing_adj:
+    file_tree = load_structure_json()
+    # Run the detection once to collect every missing item
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    print(f"Missing datasets: {missing_list}")
+
+    # Check for and download the adj data
+    if "adj" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # Remove adj from the missing list once it has been downloaded
+        missing_list.remove("adj")
 
-    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
-    baq_folder2 = os.path.join(data_dir,"AirQuality")
-    if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2):
+    # Check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
 
-    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
-    if missing_main:
-        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
-            download_kaggle_data(cwd, ds)
+    # Check for and download the TaxiBJ data (stored under the BeijingTaxi folder)
+    if "BeijingTaxi" in missing_list:
+        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_folder)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # Check for and download the PEMS, PEMS-BAY, METR-LA, and SolarEnergy data
+    # Map each dataset name to its Kaggle dataset handle
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+    # Check whether any datasets need to be downloaded from Kaggle;
+    # deduplicate the Kaggle handles first to avoid downloading the same dataset twice
+    downloaded_kaggle_datasets = set()
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # Record the handle as downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # Refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)
 
     rearrange_dir()
     return True
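The deduplication set matters because the four PEMS folders all map to the same Kaggle handle; without it, elmahy/pems-dataset could be pulled up to four times in one run. A standalone toy illustration of that collapse, with made-up missing entries:

# Two missing PEMS folders share one Kaggle handle, so only one download fires.
kaggle_map = {"PEMS03": "elmahy/pems-dataset", "PEMS04": "elmahy/pems-dataset"}
missing_list = ["PEMS03", "PEMS04"]
downloaded = set()
for name, handle in kaggle_map.items():
    if name in missing_list and handle not in downloaded:
        print("downloading", handle)  # printed once, not twice
        downloaded.add(handle)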

utils/dataset.json

@@ -36,5 +36,6 @@
         "SolarEnergy.csv"
     ],
     "BeijingAirQuality": ["data.dat", "desc.json"],
-    "AirQuality": ["data.dat"]
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
 }
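A detail worth keeping in mind when extending this file: each key is the folder name that detect_data_integrity reports, so any guard in check_and_download_data must test for exactly that string ("BeijingTaxi" here, not "TaxiBJ"). A quick self-check, assuming the file sits at utils/dataset.json as load_structure_json's default path suggests:

import json

with open("utils/dataset.json", encoding="utf-8") as f:
    tree = json.load(f)
assert "BeijingTaxi" in tree            # the key detect_data_integrity reports
assert len(tree["BeijingTaxi"]) == 5    # the five TaxiBJ .npy shards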