TrafficWheel/README.md

42 lines
1.5 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Traffic Wheel 交通轮
# 依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
```
conda create -n trafficwheel python=3.10
```
pip下载安装包
```
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
```
# 快速开始
参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel
```
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
```
- model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
- dataset_name目前支持PEMSD3PEMSD4、PEMSD7、PEMSD8
- modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
- device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数
# 测试模型
在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下共有4个文件。
- best_model.pth 保存了使验证集效果最好的checkpoint
- best_test_model.pth 保存了使测试集的效果最好的checkpoint
- DATASET.yaml 存放了训练模型时所使用的参数
- run.log 记录了训练日志。
可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。