TrafficWheel/README.md

130 lines
4.9 KiB
Markdown
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Traffic Wheel 交通轮
# 依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
```bash
conda create -n trafficwheel python=3.10
```
pip下载安装包
```bash
pip install -r requirements.txt
```
```bash
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
```
# 快速开始(暂时弃用)
参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel
```bash
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
```
- model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN、STAWnet
- dataset_name目前支持PEMSD3PEMSD4、PEMSD7、PEMSD8
- modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
- device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数
run.py会自动完成数据集下载、模型训练/评估工作。
:warning:现有的模型性能数据存放在[Result.xlsx](./Result.xlsx),不必浪费资源再跑一遍。
# 测试模型
在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下共有4个文件。
- best_model.pth 保存了使验证集效果最好的checkpoint
- best_test_model.pth 保存了使测试集的效果最好的checkpoint
- DATASET.yaml 存放了训练模型时所使用的参数
- run.log 记录了训练日志。
可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。
:warning:注意请及时删除experiments文件夹中不必要的文件要不整个文件夹会越堆越大。
# 更改配置
在config文件夹中存放了每个模型的配置文件。每个数据集单独配置使用yaml格式。
你需要找到对应模型的参数进行修改。
配置文件分为五个部分:[data, model, train, test, log]
- data一般不用改存放了模型的节点数预测窗口历史窗口等信息
- model中的参数要结合代码和论文看一般会给出推荐配置。
- train调整模型的训练细节包括batch size学习率batch_norm等。
一般不建议对基线参数进行修改,按默认跑是最稳定的。
# 开发模型
首先你需要创建一个开发分支dev并切换到开发分支
```bash
git switch -c dev
```
参考 [模型迁移教程](./transfer_guide.md) 迁移模型到TrafficWheel中。
提交更改。
```bash
git add .
git commit -m "Commit message"
```
推送到远程仓库(需要找我注册账号)
```bash
git push origin dev
```
模型开发完成后需要合并到main分支在[这里](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls)提交pull request。
# 已知问题
目前实测以下模型性能与原报告相比指标偏高ARIMA、TCN、DCRNN
STGCN在载入图时会有未知warning
以下模型由于没有源码暂未实现HA、VAR、FC-LSTM、GRU-ED
# 源代码及论文
| 论文 | 代码 |
| ------------------------------------------------------------ | ------------------------------------------------------------ |
| [HierAttnLSTM](https://arxiv.org/abs/2201.05760v4) | [代码](https://github.com/TeRyZh/Network-Level-Travel-Prediction-Hierarchical-Attention-LSTM) |
| [DSANET](https://dl.acm.org/doi/10.1145/3357384.3358132) | [代码](https://github.com/bighuang624/DSANet) |
| [STGCN](https://arxiv.org/abs/1709.04875) | [代码](https://github.com/hazdzz/STGCN) |
| [DCRNN](https://arxiv.org/abs/1707.01926) | [代码](https://github.com/chnsh/DCRNN_PyTorch) |
| [GraphWaveNet](https://arxiv.org/pdf/1906.00121.pdf) | [代码](https://github.com/SGT-LIM/GraphWavenet) |
| [STSGCN](https://aaai.org/ojs/index.php/AAAI/article/view/5438/5294) | [代码](https://github.com/SmallNana/STSGCN_Pytorch) |
| [AGCRN](https://arxiv.org/pdf/2007.02842) | [代码](https://github.com/LeiBAI/AGCRN) |
| [STFGNN](https://arxiv.org/abs/2012.09641) | [代码](https://github.com/lwm412/STFGNN-Pytorch) |
| [STGODE](https://arxiv.org/abs/2106.12931) | [代码](https://github.com/square-coder/STGODE) |
| [STG-NCDE](https://arxiv.org/abs/2112.03558) | [代码](https://github.com/jeongwhanchoi/STG-NCDE) |
| [DDGCRN](https://www.sciencedirect.com/science/article/abs/pii/S0031320323003710) | [代码](https://github.com/wengwenchao123/DDGCRN) |