Secret Projct
Go to file
harry.zhang ee955e9481 STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
configs STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
dataloader STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
examples STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
lib STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
model STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
trainer STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
.gitignore STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
LICENSE Initial commit 2025-09-01 11:03:24 +08:00
README.md STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
requirements.txt STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00
run.py STDEN工程化到当前项目 2025-09-01 11:52:33 +08:00

README.md

STDEN项目

时空扩散方程网络Spatio-Temporal Diffusion Equation Network项目用于时空序列预测任务。

项目结构

Project-I/
├── run.py                 # 主运行文件
├── configs/              # 配置文件目录
│   ├── stde_gt.yaml     # STDE_GT模型配置
│   ├── stde_wrs.yaml    # STDE_WRS模型配置
│   └── stde_zgc.yaml    # STDE_ZGC模型配置
├── dataloader/           # 数据加载器模块
│   ├── __init__.py
│   └── stden_dataloader.py
├── trainer/              # 训练器模块
│   ├── __init__.py
│   └── stden_trainer.py
├── model/                # 模型模块
│   ├── __init__.py
│   ├── stden_model.py
│   ├── stden_supervisor.py
│   ├── diffeq_solver.py
│   └── ode_func.py
├── lib/                  # 工具库
│   ├── __init__.py
│   ├── logger.py
│   ├── utils.py
│   └── metrics.py
├── requirements.txt      # 项目依赖
└── README.md            # 项目说明

快速开始

1. 安装依赖

pip install -r requirements.txt

2. 训练模型

# 训练STDE_GT模型
python run.py --model_name stde_gt --mode train

# 训练STDE_WRS模型
python run.py --model_name stde_wrs --mode train

# 训练STDE_ZGC模型
python run.py --model_name stde_zgc --mode train

3. 评估模型

# 评估STDE_GT模型
python run.py --model_name stde_gt --mode eval --save_pred

# 评估STDE_WRS模型
python run.py --model_name stde_wrs --mode eval --save_pred

# 评估STDE_ZGC模型
python run.py --model_name stde_zgc --mode eval --save_pred

配置说明

项目使用YAML格式的配置文件主要包含三个部分

数据配置 (data)

  • dataset_dir: 数据集目录路径
  • batch_size: 训练批处理大小
  • val_batch_size: 验证批处理大小
  • graph_pkl_filename: 传感器图邻接矩阵文件

模型配置 (model)

  • seq_len: 输入序列长度
  • horizon: 预测时间步数
  • input_dim: 输入特征维度
  • output_dim: 输出特征维度
  • latent_dim: 潜在空间维度
  • n_traj_samples: 轨迹采样数量
  • ode_method: ODE求解方法
  • rnn_units: RNN隐藏单元数量
  • gcn_step: 图卷积步数

训练配置 (train)

  • base_lr: 基础学习率
  • epochs: 总训练轮数
  • patience: 早停耐心值
  • optimizer: 优化器类型
  • max_grad_norm: 最大梯度范数

主要特性

  1. 模块化设计: 清晰的数据加载器、训练器、模型分离
  2. 配置驱动: 使用YAML配置文件易于调整参数
  3. 统一接口: 通过run.py统一调用不同模型
  4. 完整日志: 支持文件和控制台日志输出
  5. TensorBoard支持: 训练过程可视化
  6. 检查点管理: 自动保存和加载最佳模型

支持的模型

  • STDE_GT: 用于北京GM传感器图数据
  • STDE_WRS: 用于WRS传感器图数据
  • STDE_ZGC: 用于ZGC传感器图数据

数据格式

项目支持两种数据格式:

  1. BJ格式: 包含flow.npz文件适用于北京数据集
  2. 标准格式: 包含train.npz、val.npz、test.npz文件

日志和输出

  • 训练日志保存在logs/目录
  • 模型检查点保存在checkpoints/目录
  • TensorBoard日志保存在runs/目录
  • 预测结果保存在检查点目录的results/子目录

注意事项

  1. 确保数据集目录结构正确
  2. 传感器图文件路径配置正确
  3. 根据硬件配置调整批处理大小
  4. 训练过程中会自动创建必要的目录

许可证

本项目采用MIT许可证详见LICENSE文件。