# STEP模型适配说明 ## 概述 STEP (Pre-training Enhanced Spatial-temporal Graph Neural Network) 是一个基于预训练的时空图神经网络,用于多变量时间序列预测。该模型结合了TSFormer预训练模型和GraphWaveNet后端模型,通过离散图学习来增强下游时空图神经网络。 ## 模型架构 STEP模型包含以下主要组件: 1. **TSFormer**: 基于Transformer的预训练模型,用于学习长期时间序列表示 2. **GraphWaveNet**: 后端时空图神经网络,用于短期预测 3. **离散图学习**: 动态学习节点间的依赖关系图 ## 文件结构 ``` model/STEP/ ├── __init__.py ├── STEP.py # 主模型文件 ├── tsformer.py # TSFormer模型 ├── graphwavenet.py # GraphWaveNet模型 ├── discrete_graph_learning.py # 离散图学习模块 ├── similarity.py # 相似度计算 ├── step_loss.py # 损失函数 └── tsformer_components/ # TSFormer组件 ├── __init__.py ├── patch.py ├── mask.py ├── positional_encoding.py └── transformer_layers.py trainer/ └── STEP_Trainer.py # STEP专用训练器 dataloader/ └── STEPdataloader.py # STEP数据加载器 config/STEP/ ├── STEP_PEMS04.yaml # PEMS04数据集配置 ├── STEP_PEMS03.yaml # PEMS03数据集配置 └── STEP_METR-LA.yaml # METR-LA数据集配置 ``` ## 使用方法 ### 1. 测试模型 运行测试脚本验证模型是否正常工作: ```bash python test_step.py ``` ### 2. 训练模型 使用默认配置训练STEP模型: ```bash python train_step.py ``` 使用自定义配置文件: ```bash python train_step.py --config config/STEP/STEP_PEMS03.yaml ``` 指定训练轮数: ```bash python train_step.py --config config/STEP/STEP_PEMS04.yaml --epochs 50 ``` ### 3. 在现有框架中使用 STEP模型已经集成到现有的模型选择器中,可以通过以下方式使用: ```python from model.model_selector import model_selector # 创建STEP模型 config = { 'type': 'STEP', 'dataset_name': 'PEMS04', 'num_nodes': 307, # ... 其他参数 } model = model_selector(config) ``` ## 配置参数 ### 模型参数 - `dataset_name`: 数据集名称 (PEMS04, PEMS03, METR-LA等) - `num_nodes`: 节点数量 - `lag`: 输入序列长度 - `horizon`: 预测序列长度 - `tsformer_args`: TSFormer模型参数 - `backend_args`: GraphWaveNet后端参数 - `dgl_args`: 离散图学习参数 ### 训练参数 - `epochs`: 训练轮数 - `lr`: 学习率 - `weight_decay`: 权重衰减 - `batch_size`: 批次大小 - `clip_grad_norm`: 梯度裁剪 ## 数据集支持 STEP模型支持以下数据集: - **PEMS04**: 307个节点,交通流量数据 - **PEMS03**: 358个节点,交通流量数据 - **METR-LA**: 207个节点,交通速度数据 - **METR-BAY**: 325个节点,交通速度数据 - **PEMS07**: 883个节点,交通流量数据 - **PEMS08**: 170个节点,交通流量数据 ## 性能统计 STEP模型训练器包含完整的性能统计功能: - **显存使用**: GPU和CPU内存占用监控 - **训练效率**: 每步训练时间统计 - **推理效率**: 每步推理时间统计 - **总体统计**: 总训练时间、迭代次数等 ## 注意事项 1. **预训练模型**: STEP模型需要预训练的TSFormer模型,如果找不到预训练模型文件,会显示警告并使用随机初始化的权重。 2. **数据文件**: 离散图学习模块需要特定的数据文件,如果找不到会使用随机数据作为占位符。 3. **内存使用**: STEP模型包含多个组件,可能需要较大的GPU内存。 4. **训练时间**: 由于模型复杂度较高,训练时间可能较长。 ## 故障排除 如果遇到问题,请检查: 1. 数据文件是否存在且格式正确 2. 预训练模型文件路径是否正确 3. GPU内存是否足够 4. 依赖包是否安装完整 ## 扩展 要添加新的数据集支持,需要: 1. 在`discrete_graph_learning.py`中添加数据集配置 2. 创建对应的配置文件 3. 确保数据文件格式正确