4.1 KiB
4.1 KiB
STEP模型适配说明
概述
STEP (Pre-training Enhanced Spatial-temporal Graph Neural Network) 是一个基于预训练的时空图神经网络,用于多变量时间序列预测。该模型结合了TSFormer预训练模型和GraphWaveNet后端模型,通过离散图学习来增强下游时空图神经网络。
模型架构
STEP模型包含以下主要组件:
- TSFormer: 基于Transformer的预训练模型,用于学习长期时间序列表示
- GraphWaveNet: 后端时空图神经网络,用于短期预测
- 离散图学习: 动态学习节点间的依赖关系图
文件结构
model/STEP/
├── __init__.py
├── STEP.py # 主模型文件
├── tsformer.py # TSFormer模型
├── graphwavenet.py # GraphWaveNet模型
├── discrete_graph_learning.py # 离散图学习模块
├── similarity.py # 相似度计算
├── step_loss.py # 损失函数
└── tsformer_components/ # TSFormer组件
├── __init__.py
├── patch.py
├── mask.py
├── positional_encoding.py
└── transformer_layers.py
trainer/
└── STEP_Trainer.py # STEP专用训练器
dataloader/
└── STEPdataloader.py # STEP数据加载器
config/STEP/
├── STEP_PEMS04.yaml # PEMS04数据集配置
├── STEP_PEMS03.yaml # PEMS03数据集配置
└── STEP_METR-LA.yaml # METR-LA数据集配置
使用方法
1. 测试模型
运行测试脚本验证模型是否正常工作:
python test_step.py
2. 训练模型
使用默认配置训练STEP模型:
python train_step.py
使用自定义配置文件:
python train_step.py --config config/STEP/STEP_PEMS03.yaml
指定训练轮数:
python train_step.py --config config/STEP/STEP_PEMS04.yaml --epochs 50
3. 在现有框架中使用
STEP模型已经集成到现有的模型选择器中,可以通过以下方式使用:
from model.model_selector import model_selector
# 创建STEP模型
config = {
'type': 'STEP',
'dataset_name': 'PEMS04',
'num_nodes': 307,
# ... 其他参数
}
model = model_selector(config)
配置参数
模型参数
dataset_name: 数据集名称 (PEMS04, PEMS03, METR-LA等)num_nodes: 节点数量lag: 输入序列长度horizon: 预测序列长度tsformer_args: TSFormer模型参数backend_args: GraphWaveNet后端参数dgl_args: 离散图学习参数
训练参数
epochs: 训练轮数lr: 学习率weight_decay: 权重衰减batch_size: 批次大小clip_grad_norm: 梯度裁剪
数据集支持
STEP模型支持以下数据集:
- PEMS04: 307个节点,交通流量数据
- PEMS03: 358个节点,交通流量数据
- METR-LA: 207个节点,交通速度数据
- METR-BAY: 325个节点,交通速度数据
- PEMS07: 883个节点,交通流量数据
- PEMS08: 170个节点,交通流量数据
性能统计
STEP模型训练器包含完整的性能统计功能:
- 显存使用: GPU和CPU内存占用监控
- 训练效率: 每步训练时间统计
- 推理效率: 每步推理时间统计
- 总体统计: 总训练时间、迭代次数等
注意事项
-
预训练模型: STEP模型需要预训练的TSFormer模型,如果找不到预训练模型文件,会显示警告并使用随机初始化的权重。
-
数据文件: 离散图学习模块需要特定的数据文件,如果找不到会使用随机数据作为占位符。
-
内存使用: STEP模型包含多个组件,可能需要较大的GPU内存。
-
训练时间: 由于模型复杂度较高,训练时间可能较长。
故障排除
如果遇到问题,请检查:
- 数据文件是否存在且格式正确
- 预训练模型文件路径是否正确
- GPU内存是否足够
- 依赖包是否安装完整
扩展
要添加新的数据集支持,需要:
- 在
discrete_graph_learning.py中添加数据集配置 - 创建对应的配置文件
- 确保数据文件格式正确