交通轮,基线合集
Go to file
czzhangheng 1be0b59344 新增EXP数据加载器、模型和训练器,支持周期性数据处理和动态图构建 2025-05-08 22:43:26 +08:00
config 收到垃圾堆 2025-04-22 14:54:39 +08:00
dataloader 新增EXP数据加载器、模型和训练器,支持周期性数据处理和动态图构建 2025-05-08 22:43:26 +08:00
lib 收到垃圾堆 2025-04-22 14:54:39 +08:00
model 新增EXP数据加载器、模型和训练器,支持周期性数据处理和动态图构建 2025-05-08 22:43:26 +08:00
trainer 新增EXP数据加载器、模型和训练器,支持周期性数据处理和动态图构建 2025-05-08 22:43:26 +08:00
utils 收到垃圾堆 2025-04-22 14:54:39 +08:00
.gitignore 收到垃圾堆 2025-04-22 14:54:39 +08:00
LICENSE 收到垃圾堆 2025-04-22 14:54:39 +08:00
README.md 收到垃圾堆 2025-04-22 14:54:39 +08:00
Result.xlsx 收到垃圾堆 2025-04-22 14:54:39 +08:00
baseline.ipynb 收到垃圾堆 2025-04-22 14:54:39 +08:00
baseline1.ipynb 收到垃圾堆 2025-04-22 14:54:39 +08:00
requirements.txt 收到垃圾堆 2025-04-22 14:54:39 +08:00
run.py 收到垃圾堆 2025-04-22 14:54:39 +08:00
transfer_guide.md 收到垃圾堆 2025-04-22 14:54:39 +08:00

README.md

Traffic Wheel 交通轮

依赖环境

支持python 3.10以上版本。

使用conda创建基本环境

conda create -n trafficwheel python=3.10

pip下载安装包

pip install -r requirements.txt

pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook

快速开始

参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel

python run.py --model {model_name} --dataset {dataset_name}  --mode {train, test}  --device {cuda:0}
  • model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
  • dataset_name目前支持PEMSD3PEMSD4、PEMSD7、PEMSD8
  • modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
  • device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数

run.py会自动完成数据集下载、模型训练/评估工作。

⚠️现有的模型性能数据存放在Result.xlsx,不必浪费资源再跑一遍。

测试模型

在实验结束后,模型的存档文件会被保存在 experiments/dataset/训练时间 文件夹下共有4个文件。

  • best_model.pth 保存了使验证集效果最好的checkpoint
  • best_test_model.pth 保存了使测试集的效果最好的checkpoint
  • DATASET.yaml 存放了训练模型时所使用的参数
  • run.log 记录了训练日志。

可以创建pre-train/{dataset_name}文件夹,把整个文件夹的内容拷贝到experiments/dataset/训练时间 文件夹下的内容全部拷贝到pre-train/{dataset_name}里面。之后就可以在命令中调用 --model test进行测试。

⚠️注意请及时删除experiments文件夹中不必要的文件要不整个文件夹会越堆越大。

更改配置

在config文件夹中存放了每个模型的配置文件。每个数据集单独配置使用yaml格式。

你需要找到对应模型的参数进行修改。

配置文件分为五个部分:[data, model, train, test, log]

  • data一般不用改存放了模型的节点数预测窗口历史窗口等信息
  • model中的参数要结合代码和论文看一般会给出推荐配置。
  • train调整模型的训练细节包括batch size学习率batch_norm等。

一般不建议对基线参数进行修改,按默认跑是最稳定的。

开发模型

首先你需要创建一个开发分支dev并切换到开发分支

git switch -c dev

参考 模型迁移教程 迁移模型到TrafficWheel中。

提交更改。

git add .
git commit -m "Commit message"

推送到远程仓库(需要找我注册账号)

git push origin dev

模型开发完成后需要合并到main分支这里提交pull request。

已知问题

目前实测以下模型性能与原报告相比指标偏高ARIMA、TCN、DCRNN

STGCN在载入图时会有未知warning

以下模型由于没有源码暂未实现HA、VAR、FC-LSTM、GRU-ED

源代码及论文

论文 代码
HierAttnLSTM 代码
DSANET 代码
STGCN 代码
DCRNN 代码
GraphWaveNet 代码
STSGCN 代码
AGCRN 代码
STFGNN 代码
STGODE 代码
STG-NCDE 代码
DDGCRN 代码