42 lines
1.5 KiB
Markdown
42 lines
1.5 KiB
Markdown
# Traffic Wheel 交通轮
|
||
|
||
# 依赖环境
|
||
|
||
支持python 3.10以上版本。
|
||
|
||
使用conda创建基本环境
|
||
|
||
```
|
||
conda create -n trafficwheel python=3.10
|
||
```
|
||
|
||
pip下载安装包
|
||
|
||
```
|
||
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
|
||
```
|
||
|
||
# 快速开始
|
||
|
||
参考baseline.ipynb中的命令执行,或者使用下面的命令:(请确保当前目录为TrafficWheel)
|
||
|
||
```
|
||
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
|
||
```
|
||
|
||
- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
|
||
- dataset_name目前支持:PEMSD3,PEMSD4、PEMSD7、PEMSD8
|
||
- mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
|
||
- device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数
|
||
|
||
# 测试模型
|
||
|
||
在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下,共有4个文件。
|
||
|
||
- best_model.pth 保存了使验证集效果最好的checkpoint
|
||
- best_test_model.pth 保存了使测试集的效果最好的checkpoint
|
||
- DATASET.yaml 存放了训练模型时所使用的参数
|
||
- run.log 记录了训练日志。
|
||
|
||
可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。
|