158 lines
4.1 KiB
Markdown
158 lines
4.1 KiB
Markdown
# STEP模型适配说明
|
||
|
||
## 概述
|
||
|
||
STEP (Pre-training Enhanced Spatial-temporal Graph Neural Network) 是一个基于预训练的时空图神经网络,用于多变量时间序列预测。该模型结合了TSFormer预训练模型和GraphWaveNet后端模型,通过离散图学习来增强下游时空图神经网络。
|
||
|
||
## 模型架构
|
||
|
||
STEP模型包含以下主要组件:
|
||
|
||
1. **TSFormer**: 基于Transformer的预训练模型,用于学习长期时间序列表示
|
||
2. **GraphWaveNet**: 后端时空图神经网络,用于短期预测
|
||
3. **离散图学习**: 动态学习节点间的依赖关系图
|
||
|
||
## 文件结构
|
||
|
||
```
|
||
model/STEP/
|
||
├── __init__.py
|
||
├── STEP.py # 主模型文件
|
||
├── tsformer.py # TSFormer模型
|
||
├── graphwavenet.py # GraphWaveNet模型
|
||
├── discrete_graph_learning.py # 离散图学习模块
|
||
├── similarity.py # 相似度计算
|
||
├── step_loss.py # 损失函数
|
||
└── tsformer_components/ # TSFormer组件
|
||
├── __init__.py
|
||
├── patch.py
|
||
├── mask.py
|
||
├── positional_encoding.py
|
||
└── transformer_layers.py
|
||
|
||
trainer/
|
||
└── STEP_Trainer.py # STEP专用训练器
|
||
|
||
dataloader/
|
||
└── STEPdataloader.py # STEP数据加载器
|
||
|
||
config/STEP/
|
||
├── STEP_PEMS04.yaml # PEMS04数据集配置
|
||
├── STEP_PEMS03.yaml # PEMS03数据集配置
|
||
└── STEP_METR-LA.yaml # METR-LA数据集配置
|
||
```
|
||
|
||
## 使用方法
|
||
|
||
### 1. 测试模型
|
||
|
||
运行测试脚本验证模型是否正常工作:
|
||
|
||
```bash
|
||
python test_step.py
|
||
```
|
||
|
||
### 2. 训练模型
|
||
|
||
使用默认配置训练STEP模型:
|
||
|
||
```bash
|
||
python train_step.py
|
||
```
|
||
|
||
使用自定义配置文件:
|
||
|
||
```bash
|
||
python train_step.py --config config/STEP/STEP_PEMS03.yaml
|
||
```
|
||
|
||
指定训练轮数:
|
||
|
||
```bash
|
||
python train_step.py --config config/STEP/STEP_PEMS04.yaml --epochs 50
|
||
```
|
||
|
||
### 3. 在现有框架中使用
|
||
|
||
STEP模型已经集成到现有的模型选择器中,可以通过以下方式使用:
|
||
|
||
```python
|
||
from model.model_selector import model_selector
|
||
|
||
# 创建STEP模型
|
||
config = {
|
||
'type': 'STEP',
|
||
'dataset_name': 'PEMS04',
|
||
'num_nodes': 307,
|
||
# ... 其他参数
|
||
}
|
||
model = model_selector(config)
|
||
```
|
||
|
||
## 配置参数
|
||
|
||
### 模型参数
|
||
|
||
- `dataset_name`: 数据集名称 (PEMS04, PEMS03, METR-LA等)
|
||
- `num_nodes`: 节点数量
|
||
- `lag`: 输入序列长度
|
||
- `horizon`: 预测序列长度
|
||
- `tsformer_args`: TSFormer模型参数
|
||
- `backend_args`: GraphWaveNet后端参数
|
||
- `dgl_args`: 离散图学习参数
|
||
|
||
### 训练参数
|
||
|
||
- `epochs`: 训练轮数
|
||
- `lr`: 学习率
|
||
- `weight_decay`: 权重衰减
|
||
- `batch_size`: 批次大小
|
||
- `clip_grad_norm`: 梯度裁剪
|
||
|
||
## 数据集支持
|
||
|
||
STEP模型支持以下数据集:
|
||
|
||
- **PEMS04**: 307个节点,交通流量数据
|
||
- **PEMS03**: 358个节点,交通流量数据
|
||
- **METR-LA**: 207个节点,交通速度数据
|
||
- **METR-BAY**: 325个节点,交通速度数据
|
||
- **PEMS07**: 883个节点,交通流量数据
|
||
- **PEMS08**: 170个节点,交通流量数据
|
||
|
||
## 性能统计
|
||
|
||
STEP模型训练器包含完整的性能统计功能:
|
||
|
||
- **显存使用**: GPU和CPU内存占用监控
|
||
- **训练效率**: 每步训练时间统计
|
||
- **推理效率**: 每步推理时间统计
|
||
- **总体统计**: 总训练时间、迭代次数等
|
||
|
||
## 注意事项
|
||
|
||
1. **预训练模型**: STEP模型需要预训练的TSFormer模型,如果找不到预训练模型文件,会显示警告并使用随机初始化的权重。
|
||
|
||
2. **数据文件**: 离散图学习模块需要特定的数据文件,如果找不到会使用随机数据作为占位符。
|
||
|
||
3. **内存使用**: STEP模型包含多个组件,可能需要较大的GPU内存。
|
||
|
||
4. **训练时间**: 由于模型复杂度较高,训练时间可能较长。
|
||
|
||
## 故障排除
|
||
|
||
如果遇到问题,请检查:
|
||
|
||
1. 数据文件是否存在且格式正确
|
||
2. 预训练模型文件路径是否正确
|
||
3. GPU内存是否足够
|
||
4. 依赖包是否安装完整
|
||
|
||
## 扩展
|
||
|
||
要添加新的数据集支持,需要:
|
||
|
||
1. 在`discrete_graph_learning.py`中添加数据集配置
|
||
2. 创建对应的配置文件
|
||
3. 确保数据文件格式正确
|