Compare commits: be5e810c54...2bb113fed6

2 Commits: 2bb113fed6, 44517219b7
.gitignore:

```diff
@@ -5,6 +5,7 @@ experiments/
 *.csv
 *.npz
 *.pkl
+data/
 
 # ---> Python
 # Byte-compiled / optimized / DLL files
```
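Ignoring `data/` fits the rest of this change set: the datasets are now fetched on demand by the new `lib/Download_data.py` helper below, so they no longer need to live in the repository.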
@ -1,7 +1,8 @@
|
||||||
依赖包
|
依赖包
|
||||||
|
支持python 3.10以上版本。
|
||||||
|
|
||||||
pip install pyyaml tqdm statsmodels h5py
|
conda create -n trafficwheel python=3.10
|
||||||
pip3 install torch torchvision torchaudio
|
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
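The added `kagglehub` dependency serves the new `lib/Download_data.py` helper introduced in this commit, and the separate `pip3` line for the torch packages is folded into the single `pip install`.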
config/args_parser.py:

```diff
@@ -3,11 +3,11 @@ import yaml
 
 def parse_args():
     parser = argparse.ArgumentParser(description='Model Training and Testing')
-    parser.add_argument('--dataset', default='PEMSD7(L)', type=str)
+    parser.add_argument('--dataset', default='PEMSD8', type=str)
     parser.add_argument('--mode', default='train', type=str)
     parser.add_argument('--device', default='cuda:0', type=str, help='Indices of GPUs')
     parser.add_argument('--debug', default=False, type=eval)
-    parser.add_argument('--model', default='DDGCRN', type=str)
+    parser.add_argument('--model', default='GWN', type=str)
     parser.add_argument('--cuda', default=True, type=bool)
     parser.add_argument('--sample', default=1, type=int)
     parser.add_argument('--emb', default=12, type=int)
```
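A quick way to confirm the changed defaults, as a minimal sketch: it assumes `parse_args()` returns the parsed argparse namespace (its return statement sits outside this hunk, and `run.py` appears to post-process the result into a config dict).

```python
# Hypothetical check, run from the repository root with no CLI flags:
# the defaults changed by this commit should come back as PEMSD8 and GWN.
from config.args_parser import parse_args

args = parse_args()
print(args.dataset, args.model)  # expected: PEMSD8 GWN
```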
```diff
@@ -1,4 +1,4 @@
-import controldiffeq
+import model.STGNCDE.controldiffeq
 from lib.normalization import normalize_dataset
 import numpy as np
 import gc
```
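One subtlety worth noting: `import model.STGNCDE.controldiffeq` binds the top-level name `model`, not `controldiffeq`. If the rest of this file still refers to the module by its short name, an aliased import keeps those call sites working; a sketch, assuming the call sites were not updated:

```python
# Alias the package-qualified module back to its short name so that
# existing references like controldiffeq.<fn>(...) keep resolving.
import model.STGNCDE.controldiffeq as controldiffeq
```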
lib/Download_data.py (new file, @@ -0,0 +1,146 @@):

```python
import os
import requests
import zipfile
import shutil
import kagglehub  # assumed to be an installed, importable library
from tqdm import tqdm

# File-integrity information: the expected layout dictionary is defined
# inside check_and_download_data below.


def check_and_download_data():
    """
    Check the integrity of the data folder and download the appropriate
    data depending on which kind of file is missing.
    """
    current_working_dir = os.getcwd()  # current working directory
    data_dir = os.path.join(current_working_dir, "data")  # assume data/ lives in the cwd

    expected_structure = {
        "PEMS03": ["PEMS03.csv", "PEMS03.npz", "PEMS03.txt", "PEMS03_dtw_distance.npy", "PEMS03_spatial_distance.npy"],
        "PEMS04": ["PEMS04.csv", "PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
        "PEMS07": ["PEMS07.csv", "PEMS07.npz", "PEMS07_dtw_distance.npy", "PEMS07_spatial_distance.npy"],
        "PEMS08": ["PEMS08.csv", "PEMS08.npz", "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy"]
    }

    current_dir = os.getcwd()  # current working directory
    missing_adj = False
    missing_main_files = False

    # Check whether the data folder exists at all
    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
        # print(f"Directory {data_dir} does not exist.")
        print("Downloading all required data files...")
        missing_adj = True
        missing_main_files = True
    else:
        # Check for get_adj.py at the root of the data folder
        if "get_adj.py" not in os.listdir(data_dir):
            # print("get_adj.py is missing from the data folder root.")
            missing_adj = True

        # Walk the expected file structure
        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)

            # Check whether the subfolder exists
            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
                # print(f"Subfolder {subfolder} does not exist.")
                missing_main_files = True
                continue

            # List the files actually present in the subfolder
            actual_files = os.listdir(subfolder_path)

            # Flag any missing files
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
                    if "_dtw_distance.npy" in expected_file or "_spatial_distance.npy" in expected_file:
                        missing_adj = True
                    else:
                        missing_main_files = True

    # Dispatch the download logic based on what is missing
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir)

    return True


def download_adj_data(current_dir, max_retries=3):
    """
    Download and extract adj.zip, showing a progress bar.
    On failure, retry at most max_retries times.
    """
    url = "https://code.zhang-heng.com/static/adj.zip"
    retries = 0

    while retries <= max_retries:
        try:
            print(f"Downloading adjacency-matrix files from {url} ...")
            response = requests.get(url, stream=True)

            if response.status_code == 200:
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                t = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")

                zip_file_path = os.path.join(current_dir, "adj.zip")
                with open(zip_file_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))
                t.close()

                # print("Download finished, file saved to:", zip_file_path)

                if os.path.exists(zip_file_path):
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                        zip_ref.extractall(current_dir)
                    # print("Dataset extracted to:", current_dir)
                    os.remove(zip_file_path)  # delete the zip file
                else:
                    print("Downloaded zip file not found; skipping extraction.")
                break  # download succeeded, leave the retry loop
            else:
                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
        except Exception as e:
            print(f"Error while downloading or extracting the dataset: {e}")
            print("If the link is invalid, verify the URL or retry later.")

        retries += 1
        if retries > max_retries:
            raise Exception(f"Download failed after the maximum number of retries ({max_retries}). Check the link or your network connection.")


def download_kaggle_data(current_dir):
    """
    Download the KaggleHub dataset and merge its data folder into the
    current working directory, overwriting conflicting files if the
    target folder already exists.
    """
    try:
        print("Downloading the KaggleHub dataset...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        # print("Path to KaggleHub dataset files:", path)

        if os.path.exists(path):
            data_folder_path = os.path.join(path, "data")
            if os.path.exists(data_folder_path):
                destination_path = os.path.join(current_dir, "data")

                # Merge the folders with shutil.copytree, overwriting conflicts
                shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
                # print(f"data folder merged into: {destination_path}")
            # else:
            #     print("No data folder found; skipping the merge.")
        # else:
        #     print("KaggleHub dataset path not found; skipping.")
    except Exception as e:
        print(f"Error while downloading or processing the KaggleHub dataset: {e}")


# Entry point
if __name__ == "__main__":
    check_and_download_data()
```
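For standalone use, the helper can be run directly; this mirrors what `run.py` now does at import time (next hunk). Every path is resolved against `os.getcwd()`, so it assumes the script is launched from the repository root:

```python
# Minimal standalone check: downloads whatever is missing under ./data
# and returns True once the expected layout is in place.
from lib.Download_data import check_and_download_data

if check_and_download_data():
    print("data/ layout verified; missing files were downloaded.")
```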
run.py:

```diff
@@ -1,6 +1,11 @@
 import os
 import shutil
+
+# Check dataset integrity
+from lib.Download_data import check_and_download_data
+data_complete = check_and_download_data()
+assert data_complete is not None, "Dataset download failed, please retry!"
 
 import torch
 from datetime import datetime
 # import time
@@ -8,9 +13,10 @@ from datetime import datetime
 from config.args_parser import parse_args
 from lib.initializer import init_model, init_optimizer
 from lib.loss_function import get_loss_function
 
 from dataloader.loader_selector import get_dataloader
 from trainer.trainer_selector import select_trainer
-import yaml  # requires the PyYAML package: pip install pyyaml
+import yaml
 
+
 def main():
@@ -24,6 +30,7 @@ def main():
     args['device'] = 'cpu'
     args['model']['device'] = args['device']
 
+
     # Initialize model
     model = init_model(args['model'], device=args['device'])
 
```
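With the new defaults in place, a bare `python run.py` trains GWN on PEMSD8. Note that the dataset check runs at module import time, before `main()` or argument parsing, so any invocation of `run.py` can trigger a download.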