TrafficWheel/STEP_README.md

4.1 KiB
Raw Blame History

STEP模型适配说明

概述

STEP (Pre-training Enhanced Spatial-temporal Graph Neural Network) 是一个基于预训练的时空图神经网络用于多变量时间序列预测。该模型结合了TSFormer预训练模型和GraphWaveNet后端模型通过离散图学习来增强下游时空图神经网络。

模型架构

STEP模型包含以下主要组件

  1. TSFormer: 基于Transformer的预训练模型用于学习长期时间序列表示
  2. GraphWaveNet: 后端时空图神经网络,用于短期预测
  3. 离散图学习: 动态学习节点间的依赖关系图

文件结构

model/STEP/
├── __init__.py
├── STEP.py                    # 主模型文件
├── tsformer.py               # TSFormer模型
├── graphwavenet.py           # GraphWaveNet模型
├── discrete_graph_learning.py # 离散图学习模块
├── similarity.py             # 相似度计算
├── step_loss.py              # 损失函数
└── tsformer_components/      # TSFormer组件
    ├── __init__.py
    ├── patch.py
    ├── mask.py
    ├── positional_encoding.py
    └── transformer_layers.py

trainer/
└── STEP_Trainer.py           # STEP专用训练器

dataloader/
└── STEPdataloader.py         # STEP数据加载器

config/STEP/
├── STEP_PEMS04.yaml          # PEMS04数据集配置
├── STEP_PEMS03.yaml          # PEMS03数据集配置
└── STEP_METR-LA.yaml         # METR-LA数据集配置

使用方法

1. 测试模型

运行测试脚本验证模型是否正常工作:

python test_step.py

2. 训练模型

使用默认配置训练STEP模型

python train_step.py

使用自定义配置文件:

python train_step.py --config config/STEP/STEP_PEMS03.yaml

指定训练轮数:

python train_step.py --config config/STEP/STEP_PEMS04.yaml --epochs 50

3. 在现有框架中使用

STEP模型已经集成到现有的模型选择器中可以通过以下方式使用

from model.model_selector import model_selector

# 创建STEP模型
config = {
    'type': 'STEP',
    'dataset_name': 'PEMS04',
    'num_nodes': 307,
    # ... 其他参数
}
model = model_selector(config)

配置参数

模型参数

  • dataset_name: 数据集名称 (PEMS04, PEMS03, METR-LA等)
  • num_nodes: 节点数量
  • lag: 输入序列长度
  • horizon: 预测序列长度
  • tsformer_args: TSFormer模型参数
  • backend_args: GraphWaveNet后端参数
  • dgl_args: 离散图学习参数

训练参数

  • epochs: 训练轮数
  • lr: 学习率
  • weight_decay: 权重衰减
  • batch_size: 批次大小
  • clip_grad_norm: 梯度裁剪

数据集支持

STEP模型支持以下数据集

  • PEMS04: 307个节点交通流量数据
  • PEMS03: 358个节点交通流量数据
  • METR-LA: 207个节点交通速度数据
  • METR-BAY: 325个节点交通速度数据
  • PEMS07: 883个节点交通流量数据
  • PEMS08: 170个节点交通流量数据

性能统计

STEP模型训练器包含完整的性能统计功能

  • 显存使用: GPU和CPU内存占用监控
  • 训练效率: 每步训练时间统计
  • 推理效率: 每步推理时间统计
  • 总体统计: 总训练时间、迭代次数等

注意事项

  1. 预训练模型: STEP模型需要预训练的TSFormer模型如果找不到预训练模型文件会显示警告并使用随机初始化的权重。

  2. 数据文件: 离散图学习模块需要特定的数据文件,如果找不到会使用随机数据作为占位符。

  3. 内存使用: STEP模型包含多个组件可能需要较大的GPU内存。

  4. 训练时间: 由于模型复杂度较高,训练时间可能较长。

故障排除

如果遇到问题,请检查:

  1. 数据文件是否存在且格式正确
  2. 预训练模型文件路径是否正确
  3. GPU内存是否足够
  4. 依赖包是否安装完整

扩展

要添加新的数据集支持,需要:

  1. discrete_graph_learning.py中添加数据集配置
  2. 创建对应的配置文件
  3. 确保数据文件格式正确