Update dataset detection and download logic
This commit is contained in:
parent d06cf4e0aa
commit b7ea73bc92
@@ -4,7 +4,14 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
-        {
+        {
             "name": "Python Debugger: Current File",
             "type": "debugpy",
             "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        },
+        {
+            "name": "STID_PEMS-BAY",
+            "type": "debugpy",
+            "request": "launch",
@@ -10,25 +10,48 @@ def load_structure_json(path="utils/dataset.json"):
         return json.load(f)

 # ---------- 2. Integrity check ----------
-def detect_data_integrity(data_dir, expected, check_adj=False):
-    missing_adj, missing_main = False, False
-    if not os.path.isdir(data_dir): return True, True
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # If the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # Mark adj as missing as well
+        missing_list.append("adj")
+        return missing_list
+
+    # Check the adj-related files (distance-matrix files)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # Check the main dataset files
     for folder, files in expected.items():
         folder_path = os.path.join(data_dir, folder)
         if not os.path.isdir(folder_path):
-            if check_adj:
-                missing_adj = True
-                continue
-            missing_main = True
+            missing_list.append(folder)
             continue

         existing = set(os.listdir(folder_path))
+        has_missing_file = False

         for f in files:
-            if f not in existing:
-                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
-                    missing_adj = True
-                elif not check_adj:
-                    missing_main = True
-    return missing_adj, missing_main
+            # Skip distance-matrix files; they were already checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list

 # ---------- 3. Download the 7z archive and extract ----------
 def download_and_extract(url, dst_dir, max_retries=3):
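For context, the old boolean pair (missing_adj, missing_main) forced two separate detection calls; the new list return lets a single call drive every download decision. A minimal sketch of how a caller might consume it, assuming the functions above are importable (the module name utils.data_utils is an assumption, not confirmed by this commit):

    import os
    from utils.data_utils import load_structure_json, detect_data_integrity  # module name assumed

    data_dir = os.path.join(os.getcwd(), "data")
    missing = detect_data_integrity(data_dir, load_structure_json())

    if "adj" in missing:
        print("distance matrices missing")  # handled by download_and_extract(...) below
    for folder in (m for m in missing if m != "adj"):
        print(f"incomplete dataset folder: {folder}")  # handled by the Kaggle/GitHub helpers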
@@ -57,11 +80,34 @@ def download_and_extract(url, dst_dir, max_retries=3):
 # ---------- 4. Download Kaggle data ----------
 def download_kaggle_data(base_dir, dataset):
-    path = kagglehub.dataset_download(dataset)
-    shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
+    try:
+        print(f"Downloading kaggle dataset: {dataset}")
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)

+# ---------- 5. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
 # ---------- 6. Rearrange directories ----------
 def rearrange_dir():
     data_dir = os.path.join(os.getcwd(), "data")
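A quick usage sketch for the new GitHub helper, using one of the TaxiBJ files this commit targets (requires the requests and tqdm packages; the ghfast.top prefix in the code above is a proxy mirror in front of raw.githubusercontent.com):

    # Fetch a single TaxiBJ array into data/BeijingTaxi, streaming with a tqdm progress bar.
    download_github_data("Datasets/TaxiBJ/TaxiBJ2013.npy", "data/BeijingTaxi")

Note that if the server omits Content-Length, file_size falls back to 0 and the progress bar total is unknown; the download itself still completes.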
@@ -92,21 +138,57 @@ def rearrange_dir():
 def check_and_download_data():
     cwd = os.getcwd()
     data_dir = os.path.join(cwd, "data")
-    expected = load_structure_json()
+    file_tree = load_structure_json()

-    missing_adj, _ = detect_data_integrity(data_dir, expected, check_adj=True)
-    if missing_adj:
+    # Run detection once to collect every missing item
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    print(f"Missing datasets: {missing_list}")
+
+    # Check for and download the adj data
+    if "adj" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # Remove adj from the missing list once downloaded
+        missing_list.remove("adj")

-    baq_folder = os.path.join(data_dir, "BeijingAirQuality")
-    baq_folder2 = os.path.join(data_dir, "AirQuality")
-    if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2):
+    # Check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)

+    # Check for and download the TaxiBJ data
+    if "BeijingTaxi" in missing_list:
+        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_folder)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)

-    _, missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
-    if missing_main:
-        for ds in ["elmahy/pems-dataset", "scchuy/pemsbay", "annnnguyen/metr-la-dataset", "wangshaoqi/solar-energy"]:
-            download_kaggle_data(cwd, ds)
+    # Check for and download the PEMS, BAY, METR-LA, and solar-energy data
+    # Map dataset folder names to Kaggle dataset identifiers
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+
+    # Check whether any dataset still needs to come from Kaggle
+    # Deduplicate the Kaggle identifiers first, so the same dataset is not downloaded twice
+    downloaded_kaggle_datasets = set()
+
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # Record the identifier as downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # Refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)

     rearrange_dir()
     return True
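The rewritten entry point is effectively idempotent: it only fetches what detect_data_integrity reports missing, so it can run unconditionally at startup. A hedged sketch of wiring it in (module name assumed, as before):

    from utils.data_utils import check_and_download_data  # module name assumed

    if __name__ == "__main__":
        # Safe to call on every run; re-detection after each download keeps the list current.
        check_and_download_data()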
@@ -36,5 +36,6 @@
         "SolarEnergy.csv"
     ],
     "BeijingAirQuality": ["data.dat", "desc.json"],
-    "AirQuality": ["data.dat"]
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
 }
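The top-level keys of utils/dataset.json double as folder names under data/, and the values list the files expected inside each folder, which is why detect_data_integrity can report missing items by key. A small sketch that prints the expected layout:

    import json

    with open("utils/dataset.json", encoding="utf-8") as f:
        expected = json.load(f)

    for folder, files in expected.items():
        print(f"data/{folder}: expects {len(files)} file(s)")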