Project-I/README.md

135 lines
3.6 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# STDEN项目
时空扩散方程网络Spatio-Temporal Diffusion Equation Network项目用于时空序列预测任务。
## 项目结构
```
Project-I/
├── run.py # 主运行文件
├── configs/ # 配置文件目录
│ ├── stde_gt.yaml # STDE_GT模型配置
│ ├── stde_wrs.yaml # STDE_WRS模型配置
│ └── stde_zgc.yaml # STDE_ZGC模型配置
├── dataloader/ # 数据加载器模块
│ ├── __init__.py
│ └── stden_dataloader.py
├── trainer/ # 训练器模块
│ ├── __init__.py
│ └── stden_trainer.py
├── model/ # 模型模块
│ ├── __init__.py
│ ├── stden_model.py
│ ├── stden_supervisor.py
│ ├── diffeq_solver.py
│ └── ode_func.py
├── lib/ # 工具库
│ ├── __init__.py
│ ├── logger.py
│ ├── utils.py
│ └── metrics.py
├── requirements.txt # 项目依赖
└── README.md # 项目说明
```
## 快速开始
### 1. 安装依赖
```bash
pip install -r requirements.txt
```
### 2. 训练模型
```bash
# 训练STDE_GT模型
python run.py --model_name stde_gt --mode train
# 训练STDE_WRS模型
python run.py --model_name stde_wrs --mode train
# 训练STDE_ZGC模型
python run.py --model_name stde_zgc --mode train
```
### 3. 评估模型
```bash
# 评估STDE_GT模型
python run.py --model_name stde_gt --mode eval --save_pred
# 评估STDE_WRS模型
python run.py --model_name stde_wrs --mode eval --save_pred
# 评估STDE_ZGC模型
python run.py --model_name stde_zgc --mode eval --save_pred
```
## 配置说明
项目使用YAML格式的配置文件主要包含三个部分
### 数据配置 (data)
- `dataset_dir`: 数据集目录路径
- `batch_size`: 训练批处理大小
- `val_batch_size`: 验证批处理大小
- `graph_pkl_filename`: 传感器图邻接矩阵文件
### 模型配置 (model)
- `seq_len`: 输入序列长度
- `horizon`: 预测时间步数
- `input_dim`: 输入特征维度
- `output_dim`: 输出特征维度
- `latent_dim`: 潜在空间维度
- `n_traj_samples`: 轨迹采样数量
- `ode_method`: ODE求解方法
- `rnn_units`: RNN隐藏单元数量
- `gcn_step`: 图卷积步数
### 训练配置 (train)
- `base_lr`: 基础学习率
- `epochs`: 总训练轮数
- `patience`: 早停耐心值
- `optimizer`: 优化器类型
- `max_grad_norm`: 最大梯度范数
## 主要特性
1. **模块化设计**: 清晰的数据加载器、训练器、模型分离
2. **配置驱动**: 使用YAML配置文件易于调整参数
3. **统一接口**: 通过run.py统一调用不同模型
4. **完整日志**: 支持文件和控制台日志输出
5. **TensorBoard支持**: 训练过程可视化
6. **检查点管理**: 自动保存和加载最佳模型
## 支持的模型
- **STDE_GT**: 用于北京GM传感器图数据
- **STDE_WRS**: 用于WRS传感器图数据
- **STDE_ZGC**: 用于ZGC传感器图数据
## 数据格式
项目支持两种数据格式:
1. **BJ格式**: 包含flow.npz文件适用于北京数据集
2. **标准格式**: 包含train.npz、val.npz、test.npz文件
## 日志和输出
- 训练日志保存在`logs/`目录
- 模型检查点保存在`checkpoints/`目录
- TensorBoard日志保存在`runs/`目录
- 预测结果保存在检查点目录的`results/`子目录
## 注意事项
1. 确保数据集目录结构正确
2. 传感器图文件路径配置正确
3. 根据硬件配置调整批处理大小
4. 训练过程中会自动创建必要的目录
## 许可证
本项目采用MIT许可证详见LICENSE文件。