diff --git a/README.md b/README.md index c123b8f..4db6252 100644 --- a/README.md +++ b/README.md @@ -1,32 +1,41 @@ -依赖包 +# Traffic Wheel 交通轮 + +# 依赖环境 + 支持python 3.10以上版本。 +使用conda创建基本环境 + +``` conda create -n trafficwheel python=3.10 +``` -pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw +pip下载安装包 +``` +pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook +``` +# 快速开始 -命令 +参考baseline.ipynb中的命令执行,或者使用下面的命令:(请确保当前目录为TrafficWheel) -![image-20241214230153502](./assets/image-20241214230153502.png) +``` +python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0} +``` +- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN +- dataset_name目前支持:PEMSD3,PEMSD4、PEMSD7、PEMSD8 +- mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。 +- device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数 +# 测试模型 -添加模型: +在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下,共有4个文件。 -三步:1. 在config下新建文件夹,复制其他模型的参数项,改命令 +- best_model.pth 保存了使验证集效果最好的checkpoint +- best_test_model.pth 保存了使测试集的效果最好的checkpoint +- DATASET.yaml 存放了训练模型时所使用的参数 +- run.log 记录了训练日志。 -2. 在model新建文件夹,复制模型文件 - -![image-20241214230303239](./assets/image-20241214230303239.png) - -使用 arg['参数名']访问参数,具体参数在yaml文件中的model类下对应配置,有啥写啥。一般只要这里的参数就可以了,不需要动其他的train,data - -![image-20241214230331273](./assets/image-20241214230331273.png) - -第三步:在model/model_selector下添加自己的模型,类似这种格式 - -![image-20241214230447678](./assets/image-20241214230447678.png) - -然后就可以运行了。其中,ARIMA,VAR还没做好,不要运行。 +可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。 diff --git a/Result.xlsx b/Result.xlsx index 62bce8b..cd78009 100644 Binary files a/Result.xlsx and b/Result.xlsx differ diff --git a/assets/image-20241214230153502.png b/assets/image-20241214230153502.png deleted file mode 100644 index e2176e9..0000000 Binary files a/assets/image-20241214230153502.png and /dev/null differ diff --git a/assets/image-20241214230303239.png b/assets/image-20241214230303239.png deleted file mode 100644 index cd36ac7..0000000 Binary files a/assets/image-20241214230303239.png and /dev/null differ diff --git a/assets/image-20241214230331273.png b/assets/image-20241214230331273.png deleted file mode 100644 index d3ad2ba..0000000 Binary files a/assets/image-20241214230331273.png and /dev/null differ diff --git a/assets/image-20241214230447678.png b/assets/image-20241214230447678.png deleted file mode 100644 index 43aa779..0000000 Binary files a/assets/image-20241214230447678.png and /dev/null differ