commit b7ea73bc92
parent d06cf4e0aa

    Update the dataset detection and download logic
@@ -4,6 +4,13 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        },
         {
             "name": "STID_PEMS-BAY",
             "type": "debugpy",
@@ -10,25 +10,48 @@ def load_structure_json(path="utils/dataset.json"):
         return json.load(f)
 
 # ---------- 2. Check integrity ----------
-def detect_data_integrity(data_dir, expected, check_adj=False):
-    missing_adj, missing_main = False, False
-    if not os.path.isdir(data_dir): return True, True
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # If the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # Flag adj as missing as well
+        missing_list.append("adj")
+        return missing_list
+
+    # Check the adj files (distance matrices)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # Check the main dataset files
     for folder, files in expected.items():
         folder_path = os.path.join(data_dir, folder)
         if not os.path.isdir(folder_path):
-            if check_adj:
-                missing_adj = True
-                continue
-            missing_main = True
+            missing_list.append(folder)
             continue
 
         existing = set(os.listdir(folder_path))
+        has_missing_file = False
+
         for f in files:
-            if f not in existing:
-                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
-                    missing_adj = True
-                elif not check_adj:
-                    missing_main = True
-    return missing_adj, missing_main
+            # Skip the distance-matrix files; they were already checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list
 
 # ---------- 3. Download and extract the 7z archives ----------
 def download_and_extract(url, dst_dir, max_retries=3):
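The hunk above changes the function's contract: instead of the old (missing_adj, missing_main) boolean pair selected by a check_adj flag, one call now returns a flat list of missing folder names, with the special entry "adj" standing for the distance-matrix files. A minimal standalone sketch of that contract, using a temporary directory and a toy file tree rather than the real utils/dataset.json:

import os
import tempfile

def detect_missing(data_dir, expected):
    # Toy version of detect_data_integrity from the diff above: returns the
    # folders with missing main files, plus "adj" if any distance matrix is absent.
    adj_suffixes = ("_dtw_distance.npy", "_spatial_distance.npy")
    if not os.path.isdir(data_dir):
        return list(expected) + ["adj"]
    missing = []
    if any(f.endswith(adj_suffixes)
           and not os.path.exists(os.path.join(data_dir, folder, f))
           for folder, files in expected.items() for f in files):
        missing.append("adj")
    for folder, files in expected.items():
        folder_path = os.path.join(data_dir, folder)
        present = set(os.listdir(folder_path)) if os.path.isdir(folder_path) else set()
        if any(not f.endswith(adj_suffixes) and f not in present for f in files):
            missing.append(folder)
    return missing

with tempfile.TemporaryDirectory() as root:
    expected = {"PEMS-BAY": ["pems-bay.h5", "PEMS-BAY_dtw_distance.npy"]}
    os.makedirs(os.path.join(root, "PEMS-BAY"))
    print(detect_missing(root, expected))  # ['adj', 'PEMS-BAY']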
@@ -57,11 +80,34 @@ def download_and_extract(url, dst_dir, max_retries=3):
 # ---------- 4. Download Kaggle data ----------
 def download_kaggle_data(base_dir, dataset):
     try:
+        print(f"Downloading kaggle dataset: {dataset}")
         path = kagglehub.dataset_download(dataset)
         shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
     except Exception as e:
         print("Kaggle download failed:", dataset, e)
 
+# ---------- 5. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
 # ---------- 5. Rearrange directories ----------
 def rearrange_dir():
     data_dir = os.path.join(os.getcwd(), "data")
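The new download_github_data follows a common streaming pattern: a HEAD request first to learn Content-Length, then a chunked GET piped through a tqdm progress bar. The ghfast.top prefix appears to act as an acceleration proxy in front of raw.githubusercontent.com, with the commented-out line keeping the direct URL as a fallback. A standalone sketch of the same pattern against an arbitrary illustrative URL (assumptions: network access, and a server that reports Content-Length; tqdm still counts bytes if the total is 0):

import os
import requests
from tqdm import tqdm

def fetch(url, dest):
    # HEAD first so the progress bar knows the total size up front
    head = requests.head(url, allow_redirects=True)
    head.raise_for_status()
    total = int(head.headers.get("Content-Length", 0))
    resp = requests.get(url, stream=True, allow_redirects=True)
    resp.raise_for_status()
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    with open(dest, "wb") as f, tqdm(total=total, unit="B", unit_scale=True,
                                     desc=os.path.basename(dest)) as pbar:
        for chunk in resp.iter_content(chunk_size=1024):
            if chunk:  # skip keep-alive chunks
                f.write(chunk)
                pbar.update(len(chunk))

# Illustrative target only; any URL serving a plain file works the same way.
fetch("https://example.com/", "data/example.html")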
@@ -92,21 +138,57 @@ def rearrange_dir():
 def check_and_download_data():
     cwd = os.getcwd()
     data_dir = os.path.join(cwd,"data")
-    expected = load_structure_json()
+    file_tree = load_structure_json()
 
-    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
-    if missing_adj:
+    # Run the check once to collect every missing item
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    print(f"Missing datasets: {missing_list}")
+
+    # Check for and download the adj data
+    if "adj" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # Remove adj from the missing list after downloading
+        missing_list.remove("adj")
 
-    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
-    baq_folder2 = os.path.join(data_dir,"AirQuality")
-    if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2):
+    # Check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
         download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
 
-    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
-    if missing_main:
-        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
-            download_kaggle_data(cwd, ds)
+    # Check for and download the TaxiBJ data
+    if "BeijingTaxi" in missing_list:
+        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_folder)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # Check for and download the PEMS, PEMS-BAY, METR-LA and SolarEnergy data
+    # Map dataset names to Kaggle dataset identifiers
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+
+    # Deduplicate the Kaggle sources first so the same dataset is not downloaded twice
+    downloaded_kaggle_datasets = set()
+
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # Record the source as downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # Refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)
 
     rearrange_dir()
     return True
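Several logical datasets (PEMS03/04/07/08) share one Kaggle source, so the loop above deduplicates by source and refreshes the missing list after each download. A no-network toy sketch of the dedup pattern, with print standing in for download_kaggle_data:

kaggle_map = {
    "PEMS03": "elmahy/pems-dataset",
    "PEMS04": "elmahy/pems-dataset",
    "PEMS-BAY": "scchuy/pemsbay",
}
missing_list = ["PEMS03", "PEMS04", "PEMS-BAY"]

downloaded = set()
for dataset, source in kaggle_map.items():
    if dataset in missing_list and source not in downloaded:
        print(f"would download {source}")  # stand-in for download_kaggle_data
        downloaded.add(source)
# Prints "elmahy/pems-dataset" and "scchuy/pemsbay" once each,
# even though three datasets were missing.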
@@ -36,5 +36,6 @@
         "SolarEnergy.csv"
     ],
     "BeijingAirQuality": ["data.dat", "desc.json"],
-    "AirQuality": ["data.dat"]
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
 }
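The manifest maps each data folder to the files it must contain; load_structure_json from the first Python hunk simply loads this JSON. A trimmed, self-contained sketch of how it is consumed (the inline JSON is a stand-in for utils/dataset.json):

import json

manifest = json.loads("""
{
    "AirQuality": ["data.dat"],
    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
                    "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
}
""")
for folder, files in manifest.items():
    print(folder, "->", len(files), "expected file(s)")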