From 42e250d8e12616b4064a95bbb2dfe8ab8a5b2b08 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 4 Sep 2025 11:20:17 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0requirements.txt=E4=BE=9D?= =?UTF-8?q?=E8=B5=96=E6=96=87=E4=BB=B6=EF=BC=8C=E6=9B=B4=E6=96=B0README?= =?UTF-8?q?=E5=92=8Cutils/download.py=EF=BC=8C=E7=A7=BB=E9=99=A4STDEN?= =?UTF-8?q?=E5=AD=90=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 13 ++++- STDEN | 1 - requirements.txt | 42 +++++++++++++ utils/download.py | 146 ++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 200 insertions(+), 2 deletions(-) delete mode 160000 STDEN create mode 100644 requirements.txt create mode 100644 utils/download.py diff --git a/README.md b/README.md index e678e6d..8302f7d 100644 --- a/README.md +++ b/README.md @@ -4,6 +4,17 @@ Secret Projct mkdir -p models/gpt2 +## Download dataset +python utils/download.py + +## Download gpt weight + +mkdir -p models/gpt2 + Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main -Use pytorch >= 2.6 to load model. \ No newline at end of file +Use pytorch >= 2.6 to load model. 
+
+## Run
+
+Run: `python.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml`
\ No newline at end of file
diff --git a/STDEN b/STDEN
deleted file mode 160000
index e50a1ba..0000000
--- a/STDEN
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..0c0c649
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,42 @@
+# Core deep learning framework
+torch
+torchvision
+torchaudio
+
+# Scientific computing and data processing
+numpy
+pandas
+scipy
+
+# Machine learning tools
+scikit-learn
+
+# Configuration and file handling
+pyyaml
+
+# Progress bars
+tqdm
+
+# Graph neural networks and distance computation
+fastdtw
+
+# Differential equation solvers
+torchdiffeq
+
+# Natural language processing (for the GPT-2 model)
+transformers
+
+# Data visualization
+matplotlib
+
+# HTTP requests (for data download)
+requests
+
+# Archive handling: zipfile is part of Python's standard library,
+# so it needs no pip install and must not be listed as a requirement.
+
+# Kaggle data download
+kagglehub
+
+# Other tools
+future
diff --git a/utils/download.py b/utils/download.py
new file mode 100644
index 0000000..ed7c929
--- /dev/null
+++ b/utils/download.py
@@ -0,0 +1,146 @@
+"""Verify the local ``data`` directory and download any missing PEMS files."""
+
+import os
+import requests
+import zipfile
+import shutil
+import kagglehub  # assumes the kagglehub package is installed
+from tqdm import tqdm
+
+# Expected layout of the data directory: sub-folder name -> required files.
+EXPECTED_STRUCTURE = {
+    "PEMS03": ["PEMS03.csv", "PEMS03.npz", "PEMS03.txt", "PEMS03_dtw_distance.npy", "PEMS03_spatial_distance.npy"],
+    "PEMS04": ["PEMS04.csv", "PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
+    "PEMS07": ["PEMS07.csv", "PEMS07.npz", "PEMS07_dtw_distance.npy", "PEMS07_spatial_distance.npy"],
+    "PEMS08": ["PEMS08.csv", "PEMS08.npz", "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy"],
+}
+
+# Suffixes of the distance-matrix files that trigger the adj.zip download.
+ADJ_SUFFIXES = ("_dtw_distance.npy", "_spatial_distance.npy")
+
+
+def check_and_download_data():
+    """Check ``data`` under the working directory and download what is missing.
+
+    A missing ``get_adj.py`` or distance matrix triggers the adj.zip download;
+    any other missing file triggers the Kaggle download.
+
+    Returns:
+        bool: ``True`` when nothing was missing or every needed download
+        succeeded, ``False`` when the Kaggle download failed.
+
+    Raises:
+        RuntimeError: if adj.zip is needed but cannot be downloaded.
+    """
+    current_dir = os.getcwd()
+    data_dir = os.path.join(current_dir, "data")
+    missing_adj = False
+    missing_main_files = False
+
+    if not os.path.isdir(data_dir):
+        print("正在下载所有必要的数据文件...")
+        missing_adj = True
+        missing_main_files = True
+    else:
+        if "get_adj.py" not in os.listdir(data_dir):
+            missing_adj = True
+
+        for subfolder, expected_files in EXPECTED_STRUCTURE.items():
+            subfolder_path = os.path.join(data_dir, subfolder)
+            # A missing sub-folder is treated as empty so all its files count as missing.
+            if os.path.isdir(subfolder_path):
+                actual_files = set(os.listdir(subfolder_path))
+            else:
+                actual_files = set()
+
+            for expected_file in expected_files:
+                if expected_file in actual_files:
+                    continue
+                if expected_file.endswith(ADJ_SUFFIXES):
+                    missing_adj = True
+                else:
+                    missing_main_files = True
+
+    if missing_adj:
+        download_adj_data(current_dir)
+    if missing_main_files:
+        return download_kaggle_data(current_dir)
+    return True
+
+
+def download_adj_data(current_dir, max_retries=3):
+    """Download adj.zip with a progress bar and extract it into ``current_dir``.
+
+    Args:
+        current_dir: directory that receives the archive and its contents.
+        max_retries: number of retries allowed after the first attempt.
+
+    Raises:
+        RuntimeError: if every attempt fails.
+    """
+    url = "https://code.zhang-heng.com/static/adj.zip"
+    zip_file_path = os.path.join(current_dir, "adj.zip")
+    last_error = None
+
+    for _ in range(max_retries + 1):
+        try:
+            print(f"正在从 {url} 下载邻接矩阵文件...")
+            # The timeout keeps a stalled connection from hanging forever.
+            with requests.get(url, stream=True, timeout=30) as response:
+                if response.status_code != 200:
+                    print(f"下载失败,状态码: {response.status_code}。请检查链接是否有效。")
+                    continue
+
+                total_size = int(response.headers.get('content-length', 0))
+                progress = tqdm(total=total_size, unit='B', unit_scale=True, desc="下载进度")
+                with progress, open(zip_file_path, 'wb') as f:
+                    for chunk in response.iter_content(1024):
+                        f.write(chunk)
+                        progress.update(len(chunk))
+
+            with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+                zip_ref.extractall(current_dir)
+            return
+        except (requests.RequestException, zipfile.BadZipFile, OSError, ValueError) as e:
+            last_error = e
+            print(f"下载或解压数据集时出错: {e}")
+            print("如果链接无效,请检查URL的合法性或稍后重试。")
+        finally:
+            # Remove the archive after extraction, and also after a failed
+            # attempt so that no partial download is left behind.
+            if os.path.exists(zip_file_path):
+                os.remove(zip_file_path)
+
+    raise RuntimeError(
+        f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。"
+    ) from last_error
+
+
+def download_kaggle_data(current_dir):
+    """Download the Kaggle PEMS dataset and merge its ``data`` folder.
+
+    The folder is copied into ``current_dir/data``; conflicting files are
+    overwritten.
+
+    Returns:
+        bool: ``True`` if the folder was merged, ``False`` otherwise. Failures
+        are printed rather than raised.
+    """
+    try:
+        print("正在下载 PEMS 数据集...")
+        path = kagglehub.dataset_download("elmahy/pems-dataset")
+        data_folder_path = os.path.join(path, "data")
+        if not os.path.isdir(data_folder_path):
+            print("未找到 data 文件夹,跳过合并操作。")
+            return False
+
+        # dirs_exist_ok=True lets copytree merge into an existing folder.
+        shutil.copytree(data_folder_path, os.path.join(current_dir, "data"), dirs_exist_ok=True)
+        return True
+    except Exception as e:  # best effort: report instead of crashing
+        print(f"下载或处理 KaggleHub 数据集时出错: {e}")
+        return False
+
+
+if __name__ == "__main__":
+    check_and_download_data()
\ No newline at end of file