|
|
||
|---|---|---|
| .vscode | ||
| config | ||
| dataloader | ||
| lib | ||
| model | ||
| trainer | ||
| utils | ||
| .gitignore | ||
| LICENSE | ||
| README.md | ||
| Result.md | ||
| Result.xlsx | ||
| baseline.ipynb | ||
| baseline1.ipynb | ||
| requirements.txt | ||
| run.py | ||
| transfer_guide.md | ||
README.md
Traffic Wheel 交通轮
依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
conda create -n trafficwheel python=3.10
pip下载安装包
pip install -r requirements.txt
或
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
快速开始
参考baseline.ipynb中的命令执行,或者使用下面的命令:(请确保当前目录为TrafficWheel)
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN、STAWnet
- dataset_name目前支持:PEMSD3,PEMSD4、PEMSD7、PEMSD8
- mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
- device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数
run.py会自动完成数据集下载、模型训练/评估工作。
⚠️现有的模型性能数据存放在Result.xlsx,不必浪费资源再跑一遍。
测试模型
在实验结束后,模型的存档文件会被保存在 experiments/dataset/训练时间 文件夹下,共有4个文件。
- best_model.pth 保存了使验证集效果最好的checkpoint
- best_test_model.pth 保存了使测试集的效果最好的checkpoint
- DATASET.yaml 存放了训练模型时所使用的参数
- run.log 记录了训练日志。
可以创建pre-train/{dataset_name}文件夹,把整个文件夹的内容拷贝到experiments/dataset/训练时间 文件夹下的内容全部拷贝到pre-train/{dataset_name}里面。之后就可以在命令中调用 --model test进行测试。
⚠️注意,请及时删除experiments文件夹中不必要的文件,要不整个文件夹会越堆越大。
更改配置
在config文件夹中,存放了每个模型的配置文件。每个数据集单独配置,使用yaml格式。
你需要找到对应模型的参数进行修改。
配置文件分为五个部分:[data, model, train, test, log]
- data一般不用改,存放了模型的节点数,预测窗口,历史窗口等信息
- model中的参数要结合代码和论文看,一般会给出推荐配置。
- train调整模型的训练细节,包括batch size,学习率,batch_norm等。
一般不建议对基线参数进行修改,按默认跑是最稳定的。
开发模型
首先你需要创建一个开发分支dev,并切换到开发分支
git switch -c dev
参考 模型迁移教程 迁移模型到TrafficWheel中。
提交更改。
git add .
git commit -m "Commit message"
推送到远程仓库(需要找我注册账号)
git push origin dev
模型开发完成后,需要合并到main分支,在这里提交pull request。
已知问题
目前,实测以下模型性能与原报告相比指标偏高:ARIMA、TCN、DCRNN
STGCN在载入图时会有未知warning
以下模型由于没有源码暂未实现:HA、VAR、FC-LSTM、GRU-ED
源代码及论文
| 论文 | 代码 |
|---|---|
| HierAttnLSTM | 代码 |
| DSANET | 代码 |
| STGCN | 代码 |
| DCRNN | 代码 |
| GraphWaveNet | 代码 |
| STSGCN | 代码 |
| AGCRN | 代码 |
| STFGNN | 代码 |
| STGODE | 代码 |
| STG-NCDE | 代码 |
| DDGCRN | 代码 |