Add requirements.txt, update README and utils/download.py, and remove the STDEN submodule

czzhangheng 2025-09-04 11:20:17 +08:00
parent e9e3da03d3
commit 42e250d8e1
4 changed files with 200 additions and 2 deletions

README.md

@@ -4,6 +4,17 @@ Secret Project
mkdir -p models/gpt2
## Download dataset
python utils/download.py
## Download GPT-2 weights
mkdir -p models/gpt2
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
Use PyTorch >= 2.6 to load the model.
## Run
Run: `python.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml`
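
For a quick sanity check after the manual download, a minimal loading sketch (an assumption, not part of this commit; the local path `models/gpt2` comes from the steps above):

```python
# Hypothetical sketch: load the locally downloaded GPT-2 weights with transformers.
# Assumes config.json and pytorch_model.bin sit in models/gpt2 as described above.
from transformers import GPT2Model

model = GPT2Model.from_pretrained("models/gpt2")
model.eval()  # inference mode
```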

STDEN (submodule, 1 deletion)

@@ -1 +0,0 @@
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2

requirements.txt (new file, 42 lines)

@@ -0,0 +1,42 @@
# Core deep-learning framework
torch
torchvision
torchaudio
# Scientific computing and data processing
numpy
pandas
scipy
# Machine-learning utilities
scikit-learn
# Configuration and file handling
pyyaml
# Progress bars
tqdm
# Graph neural networks and distance computation
fastdtw
# Differential-equation solvers
torchdiffeq
# Natural language processing (for the GPT-2 model)
transformers
# Data visualization
matplotlib
# HTTP requests (for data download)
requests
# File compression is handled by zipfile from the Python standard library (no pip package needed)
# Kaggle data download
kagglehub
# Misc utilities
future
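
The dependencies above are unpinned; they can be installed in one step with `pip install -r requirements.txt` inside the project's virtual environment.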

utils/download.py (new file, 146 lines)

@@ -0,0 +1,146 @@
import os
import shutil
import zipfile

import requests
import kagglehub  # assumed to be installed (listed in requirements.txt)
from tqdm import tqdm


def check_and_download_data():
    """
    Check the integrity of the data folder and download whichever
    category of files (adjacency matrices or main datasets) is missing.
    """
    data_dir = os.path.join(os.getcwd(), "data")  # data/ is expected in the working directory
    # Expected files per dataset subfolder.
    expected_structure = {
        "PEMS03": ["PEMS03.csv", "PEMS03.npz", "PEMS03.txt", "PEMS03_dtw_distance.npy", "PEMS03_spatial_distance.npy"],
        "PEMS04": ["PEMS04.csv", "PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
        "PEMS07": ["PEMS07.csv", "PEMS07.npz", "PEMS07_dtw_distance.npy", "PEMS07_spatial_distance.npy"],
        "PEMS08": ["PEMS08.csv", "PEMS08.npz", "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy"],
    }
    current_dir = os.getcwd()
    missing_adj = False
    missing_main_files = False

    # If data/ does not exist at all, everything must be downloaded.
    if not os.path.isdir(data_dir):
        print("Downloading all required data files...")
        missing_adj = True
        missing_main_files = True
    else:
        # get_adj.py is expected at the root of data/.
        if "get_adj.py" not in os.listdir(data_dir):
            missing_adj = True
        # Walk the expected structure and record what is missing.
        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)
            if not os.path.isdir(subfolder_path):
                missing_main_files = True
                continue
            actual_files = os.listdir(subfolder_path)
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # Distance matrices come from adj.zip; everything else from Kaggle.
                    if "_dtw_distance.npy" in expected_file or "_spatial_distance.npy" in expected_file:
                        missing_adj = True
                    else:
                        missing_main_files = True

    # Trigger the appropriate download path(s).
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir)
    return True


def download_adj_data(current_dir, max_retries=3):
    """
    Download and extract adj.zip with a progress bar,
    retrying at most max_retries times on failure.
    """
    url = "https://code.zhang-heng.com/static/adj.zip"
    retries = 0
    while retries <= max_retries:
        try:
            print(f"Downloading adjacency-matrix files from {url} ...")
            response = requests.get(url, stream=True)
            if response.status_code == 200:
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                t = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
                zip_file_path = os.path.join(current_dir, "adj.zip")
                with open(zip_file_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))
                t.close()
                if os.path.exists(zip_file_path):
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                        zip_ref.extractall(current_dir)
                    os.remove(zip_file_path)  # remove the archive after extraction
                else:
                    print("Downloaded zip file not found; skipping extraction.")
                break  # download succeeded, leave the retry loop
            else:
                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
        except Exception as e:
            print(f"Error while downloading or extracting the dataset: {e}")
            print("If the link is invalid, verify the URL or retry later.")
        retries += 1
        if retries > max_retries:
            raise Exception(f"Download failed after the maximum of {max_retries} retries. Check the link or your network connection.")


def download_kaggle_data(current_dir):
    """
    Download the KaggleHub dataset and merge its data folder into the
    current working directory, overwriting conflicting files.
    """
    try:
        print("Downloading the PEMS dataset...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        if os.path.exists(path):
            data_folder_path = os.path.join(path, "data")
            if os.path.exists(data_folder_path):
                destination_path = os.path.join(current_dir, "data")
                # Merge folders; dirs_exist_ok=True (Python >= 3.8) overwrites conflicting files.
                shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
    except Exception as e:
        print(f"Error while downloading or processing the KaggleHub dataset: {e}")


if __name__ == "__main__":
    check_and_download_data()
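
For context, a minimal usage sketch (hypothetical, not part of this commit): a training entry point could invoke the checker before building dataloaders, assuming the repository root is the working directory so that `utils` is importable.

```python
# Hypothetical caller: fetch data before training starts.
from utils.download import check_and_download_data

if __name__ == "__main__":
    check_and_download_data()  # downloads adj.zip / PEMS data only if missing
    # ... build datasets and start training here ...
```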