135 lines
3.6 KiB
Markdown
135 lines
3.6 KiB
Markdown
# STDEN项目
|
||
|
||
时空扩散方程网络(Spatio-Temporal Diffusion Equation Network)项目,用于时空序列预测任务。
|
||
|
||
## 项目结构
|
||
|
||
```
|
||
Project-I/
|
||
├── run.py # 主运行文件
|
||
├── configs/ # 配置文件目录
|
||
│ ├── stde_gt.yaml # STDE_GT模型配置
|
||
│ ├── stde_wrs.yaml # STDE_WRS模型配置
|
||
│ └── stde_zgc.yaml # STDE_ZGC模型配置
|
||
├── dataloader/ # 数据加载器模块
|
||
│ ├── __init__.py
|
||
│ └── stden_dataloader.py
|
||
├── trainer/ # 训练器模块
|
||
│ ├── __init__.py
|
||
│ └── stden_trainer.py
|
||
├── model/ # 模型模块
|
||
│ ├── __init__.py
|
||
│ ├── stden_model.py
|
||
│ ├── stden_supervisor.py
|
||
│ ├── diffeq_solver.py
|
||
│ └── ode_func.py
|
||
├── lib/ # 工具库
|
||
│ ├── __init__.py
|
||
│ ├── logger.py
|
||
│ ├── utils.py
|
||
│ └── metrics.py
|
||
├── requirements.txt # 项目依赖
|
||
└── README.md # 项目说明
|
||
```
|
||
|
||
## 快速开始
|
||
|
||
### 1. 安装依赖
|
||
|
||
```bash
|
||
pip install -r requirements.txt
|
||
```
|
||
|
||
### 2. 训练模型
|
||
|
||
```bash
|
||
# 训练STDE_GT模型
|
||
python run.py --model_name stde_gt --mode train
|
||
|
||
# 训练STDE_WRS模型
|
||
python run.py --model_name stde_wrs --mode train
|
||
|
||
# 训练STDE_ZGC模型
|
||
python run.py --model_name stde_zgc --mode train
|
||
```
|
||
|
||
### 3. 评估模型
|
||
|
||
```bash
|
||
# 评估STDE_GT模型
|
||
python run.py --model_name stde_gt --mode eval --save_pred
|
||
|
||
# 评估STDE_WRS模型
|
||
python run.py --model_name stde_wrs --mode eval --save_pred
|
||
|
||
# 评估STDE_ZGC模型
|
||
python run.py --model_name stde_zgc --mode eval --save_pred
|
||
```
|
||
|
||
## 配置说明
|
||
|
||
项目使用YAML格式的配置文件,主要包含三个部分:
|
||
|
||
### 数据配置 (data)
|
||
- `dataset_dir`: 数据集目录路径
|
||
- `batch_size`: 训练批处理大小
|
||
- `val_batch_size`: 验证批处理大小
|
||
- `graph_pkl_filename`: 传感器图邻接矩阵文件
|
||
|
||
### 模型配置 (model)
|
||
- `seq_len`: 输入序列长度
|
||
- `horizon`: 预测时间步数
|
||
- `input_dim`: 输入特征维度
|
||
- `output_dim`: 输出特征维度
|
||
- `latent_dim`: 潜在空间维度
|
||
- `n_traj_samples`: 轨迹采样数量
|
||
- `ode_method`: ODE求解方法
|
||
- `rnn_units`: RNN隐藏单元数量
|
||
- `gcn_step`: 图卷积步数
|
||
|
||
### 训练配置 (train)
|
||
- `base_lr`: 基础学习率
|
||
- `epochs`: 总训练轮数
|
||
- `patience`: 早停耐心值
|
||
- `optimizer`: 优化器类型
|
||
- `max_grad_norm`: 最大梯度范数
|
||
|
||
## 主要特性
|
||
|
||
1. **模块化设计**: 清晰的数据加载器、训练器、模型分离
|
||
2. **配置驱动**: 使用YAML配置文件,易于调整参数
|
||
3. **统一接口**: 通过run.py统一调用不同模型
|
||
4. **完整日志**: 支持文件和控制台日志输出
|
||
5. **TensorBoard支持**: 训练过程可视化
|
||
6. **检查点管理**: 自动保存和加载最佳模型
|
||
|
||
## 支持的模型
|
||
|
||
- **STDE_GT**: 用于北京GM传感器图数据
|
||
- **STDE_WRS**: 用于WRS传感器图数据
|
||
- **STDE_ZGC**: 用于ZGC传感器图数据
|
||
|
||
## 数据格式
|
||
|
||
项目支持两种数据格式:
|
||
|
||
1. **BJ格式**: 包含flow.npz文件,适用于北京数据集
|
||
2. **标准格式**: 包含train.npz、val.npz、test.npz文件
|
||
|
||
## 日志和输出
|
||
|
||
- 训练日志保存在`logs/`目录
|
||
- 模型检查点保存在`checkpoints/`目录
|
||
- TensorBoard日志保存在`runs/`目录
|
||
- 预测结果保存在检查点目录的`results/`子目录
|
||
|
||
## 注意事项
|
||
|
||
1. 确保数据集目录结构正确
|
||
2. 传感器图文件路径配置正确
|
||
3. 根据硬件配置调整批处理大小
|
||
4. 训练过程中会自动创建必要的目录
|
||
|
||
## 许可证
|
||
|
||
本项目采用MIT许可证,详见LICENSE文件。 |