# STDEN项目 时空扩散方程网络(Spatio-Temporal Diffusion Equation Network)项目,用于时空序列预测任务。 ## 项目结构 ``` Project-I/ ├── run.py # 主运行文件 ├── configs/ # 配置文件目录 │ ├── stde_gt.yaml # STDE_GT模型配置 │ ├── stde_wrs.yaml # STDE_WRS模型配置 │ └── stde_zgc.yaml # STDE_ZGC模型配置 ├── dataloader/ # 数据加载器模块 │ ├── __init__.py │ └── stden_dataloader.py ├── trainer/ # 训练器模块 │ ├── __init__.py │ └── stden_trainer.py ├── model/ # 模型模块 │ ├── __init__.py │ ├── stden_model.py │ ├── stden_supervisor.py │ ├── diffeq_solver.py │ └── ode_func.py ├── lib/ # 工具库 │ ├── __init__.py │ ├── logger.py │ ├── utils.py │ └── metrics.py ├── requirements.txt # 项目依赖 └── README.md # 项目说明 ``` ## 快速开始 ### 1. 安装依赖 ```bash pip install -r requirements.txt ``` ### 2. 训练模型 ```bash # 训练STDE_GT模型 python run.py --model_name stde_gt --mode train # 训练STDE_WRS模型 python run.py --model_name stde_wrs --mode train # 训练STDE_ZGC模型 python run.py --model_name stde_zgc --mode train ``` ### 3. 评估模型 ```bash # 评估STDE_GT模型 python run.py --model_name stde_gt --mode eval --save_pred # 评估STDE_WRS模型 python run.py --model_name stde_wrs --mode eval --save_pred # 评估STDE_ZGC模型 python run.py --model_name stde_zgc --mode eval --save_pred ``` ## 配置说明 项目使用YAML格式的配置文件,主要包含三个部分: ### 数据配置 (data) - `dataset_dir`: 数据集目录路径 - `batch_size`: 训练批处理大小 - `val_batch_size`: 验证批处理大小 - `graph_pkl_filename`: 传感器图邻接矩阵文件 ### 模型配置 (model) - `seq_len`: 输入序列长度 - `horizon`: 预测时间步数 - `input_dim`: 输入特征维度 - `output_dim`: 输出特征维度 - `latent_dim`: 潜在空间维度 - `n_traj_samples`: 轨迹采样数量 - `ode_method`: ODE求解方法 - `rnn_units`: RNN隐藏单元数量 - `gcn_step`: 图卷积步数 ### 训练配置 (train) - `base_lr`: 基础学习率 - `epochs`: 总训练轮数 - `patience`: 早停耐心值 - `optimizer`: 优化器类型 - `max_grad_norm`: 最大梯度范数 ## 主要特性 1. **模块化设计**: 清晰的数据加载器、训练器、模型分离 2. **配置驱动**: 使用YAML配置文件,易于调整参数 3. **统一接口**: 通过run.py统一调用不同模型 4. **完整日志**: 支持文件和控制台日志输出 5. **TensorBoard支持**: 训练过程可视化 6. **检查点管理**: 自动保存和加载最佳模型 ## 支持的模型 - **STDE_GT**: 用于北京GM传感器图数据 - **STDE_WRS**: 用于WRS传感器图数据 - **STDE_ZGC**: 用于ZGC传感器图数据 ## 数据格式 项目支持两种数据格式: 1. **BJ格式**: 包含flow.npz文件,适用于北京数据集 2. **标准格式**: 包含train.npz、val.npz、test.npz文件 ## 日志和输出 - 训练日志保存在`logs/`目录 - 模型检查点保存在`checkpoints/`目录 - TensorBoard日志保存在`runs/`目录 - 预测结果保存在检查点目录的`results/`子目录 ## 注意事项 1. 确保数据集目录结构正确 2. 传感器图文件路径配置正确 3. 根据硬件配置调整批处理大小 4. 训练过程中会自动创建必要的目录 ## 许可证 本项目采用MIT许可证,详见LICENSE文件。