交通轮,基线合集
Go to file
HengZhang 5df1c810b1 edit README 2025-03-03 13:19:03 +08:00
config add dl function 2025-03-03 11:34:36 +08:00
dataloader add dl function 2025-03-03 11:34:36 +08:00
lib add dl function 2025-03-03 11:34:36 +08:00
model mv dir name 2025-03-03 10:29:42 +08:00
trainer init 2025-03-02 23:41:12 +08:00
.gitignore add dl function 2025-03-03 11:34:36 +08:00
LICENSE Initial commit 2025-03-02 23:39:13 +08:00
README.md edit README 2025-03-03 13:19:03 +08:00
Result.xlsx edit README 2025-03-03 13:19:03 +08:00
baseline.ipynb init 2025-03-02 23:41:12 +08:00
run.py add dl function 2025-03-03 11:34:36 +08:00

README.md

Traffic Wheel 交通轮

依赖环境

支持python 3.10以上版本。

使用conda创建基本环境

conda create -n trafficwheel python=3.10

pip下载安装包

pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook

快速开始

参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel

python run.py --model {model_name} --dataset {dataset_name}  --mode {train, test}  --device {cuda:0}
  • model_name: 目前支持DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN
  • dataset_name目前支持PEMSD3PEMSD4、PEMSD7、PEMSD8
  • modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
  • device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数

测试模型

在实验结束后,模型的存档文件会被保存在 experiments/dataset/训练时间 文件夹下共有4个文件。

  • best_model.pth 保存了使验证集效果最好的checkpoint
  • best_test_model.pth 保存了使测试集的效果最好的checkpoint
  • DATASET.yaml 存放了训练模型时所使用的参数
  • run.log 记录了训练日志。

可以创建pre-train/{dataset_name}文件夹,把整个文件夹的内容拷贝到experiments/dataset/训练时间 文件夹下的内容全部拷贝到pre-train/{dataset_name}里面。之后就可以在命令中调用 --model test进行测试。