3.6 KiB
3.6 KiB
STDEN项目
时空扩散方程网络(Spatio-Temporal Diffusion Equation Network)项目,用于时空序列预测任务。
项目结构
Project-I/
├── run.py # 主运行文件
├── configs/ # 配置文件目录
│ ├── stde_gt.yaml # STDE_GT模型配置
│ ├── stde_wrs.yaml # STDE_WRS模型配置
│ └── stde_zgc.yaml # STDE_ZGC模型配置
├── dataloader/ # 数据加载器模块
│ ├── __init__.py
│ └── stden_dataloader.py
├── trainer/ # 训练器模块
│ ├── __init__.py
│ └── stden_trainer.py
├── model/ # 模型模块
│ ├── __init__.py
│ ├── stden_model.py
│ ├── stden_supervisor.py
│ ├── diffeq_solver.py
│ └── ode_func.py
├── lib/ # 工具库
│ ├── __init__.py
│ ├── logger.py
│ ├── utils.py
│ └── metrics.py
├── requirements.txt # 项目依赖
└── README.md # 项目说明
快速开始
1. 安装依赖
pip install -r requirements.txt
2. 训练模型
# 训练STDE_GT模型
python run.py --model_name stde_gt --mode train
# 训练STDE_WRS模型
python run.py --model_name stde_wrs --mode train
# 训练STDE_ZGC模型
python run.py --model_name stde_zgc --mode train
3. 评估模型
# 评估STDE_GT模型
python run.py --model_name stde_gt --mode eval --save_pred
# 评估STDE_WRS模型
python run.py --model_name stde_wrs --mode eval --save_pred
# 评估STDE_ZGC模型
python run.py --model_name stde_zgc --mode eval --save_pred
配置说明
项目使用YAML格式的配置文件,主要包含三个部分:
数据配置 (data)
dataset_dir: 数据集目录路径batch_size: 训练批处理大小val_batch_size: 验证批处理大小graph_pkl_filename: 传感器图邻接矩阵文件
模型配置 (model)
seq_len: 输入序列长度horizon: 预测时间步数input_dim: 输入特征维度output_dim: 输出特征维度latent_dim: 潜在空间维度n_traj_samples: 轨迹采样数量ode_method: ODE求解方法rnn_units: RNN隐藏单元数量gcn_step: 图卷积步数
训练配置 (train)
base_lr: 基础学习率epochs: 总训练轮数patience: 早停耐心值optimizer: 优化器类型max_grad_norm: 最大梯度范数
主要特性
- 模块化设计: 清晰的数据加载器、训练器、模型分离
- 配置驱动: 使用YAML配置文件,易于调整参数
- 统一接口: 通过run.py统一调用不同模型
- 完整日志: 支持文件和控制台日志输出
- TensorBoard支持: 训练过程可视化
- 检查点管理: 自动保存和加载最佳模型
支持的模型
- STDE_GT: 用于北京GM传感器图数据
- STDE_WRS: 用于WRS传感器图数据
- STDE_ZGC: 用于ZGC传感器图数据
数据格式
项目支持两种数据格式:
- BJ格式: 包含flow.npz文件,适用于北京数据集
- 标准格式: 包含train.npz、val.npz、test.npz文件
日志和输出
- 训练日志保存在
logs/目录 - 模型检查点保存在
checkpoints/目录 - TensorBoard日志保存在
runs/目录 - 预测结果保存在检查点目录的
results/子目录
注意事项
- 确保数据集目录结构正确
- 传感器图文件路径配置正确
- 根据硬件配置调整批处理大小
- 训练过程中会自动创建必要的目录
许可证
本项目采用MIT许可证,详见LICENSE文件。