From 2800f66dfe62775b710266e7e7aa4f98f3fc5d2b Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 9 Nov 2025 20:43:00 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=A0=E9=99=A4=E5=86=97=E5=8F=91=E4=BD=99?= =?UTF-8?q?=E4=BB=A3=E7=A0=81=E3=80=82=E9=87=8D=E6=9E=84=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 103 +++++------------------------------------------------- 1 file changed, 9 insertions(+), 94 deletions(-) diff --git a/README.md b/README.md index 7071e33..2ee8496 100755 --- a/README.md +++ b/README.md @@ -22,108 +22,23 @@ pip install -r requirements.txt pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook ``` +# 准备GPT预训练权重 +需要海外网络,如果没有海外网络,手动下载后上传。 -# 快速开始(暂时弃用) - -参考baseline.ipynb中的命令执行,或者使用下面的命令:(请确保当前目录为TrafficWheel) +GPT-2文件夹内应该有两个文件:`{config.json, pytorch_model.bin}` ```bash -python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0} +mkdir GPT-2 +wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json +wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin ``` -- model_name: 目前支持:DSANET、STGCN、DCRNN、 GWN(GraphWaveNet)、STSGCN、AGCRN、STFGNN、STGODE、STGNCDE、DDGCRN、TWDGCN、STAWnet -- dataset_name目前支持:PEMSD3,PEMSD4、PEMSD7、PEMSD8 -- mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。 -- device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数 -run.py会自动完成数据集下载、模型训练/评估工作。 - -:warning:现有的模型性能数据存放在[Result.xlsx](./Result.xlsx),不必浪费资源再跑一遍。 - - - -# 测试模型 - -在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下,共有4个文件。 - -- best_model.pth 保存了使验证集效果最好的checkpoint -- best_test_model.pth 保存了使测试集的效果最好的checkpoint -- DATASET.yaml 存放了训练模型时所使用的参数 -- run.log 记录了训练日志。 - -可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。 - -:warning:注意,请及时删除experiments文件夹中不必要的文件,要不整个文件夹会越堆越大。 - - - -# 更改配置 - -在config文件夹中,存放了每个模型的配置文件。每个数据集单独配置,使用yaml格式。 - -你需要找到对应模型的参数进行修改。 - -配置文件分为五个部分:[data, model, train, test, log] - -- data一般不用改,存放了模型的节点数,预测窗口,历史窗口等信息 -- model中的参数要结合代码和论文看,一般会给出推荐配置。 -- train调整模型的训练细节,包括batch size,学习率,batch_norm等。 - -一般不建议对基线参数进行修改,按默认跑是最稳定的。 - - - -# 开发模型 - -首先你需要创建一个开发分支dev,并切换到开发分支 +# 跑REPST +第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8。 ```bash -git switch -c dev +python run.py --config ./config/REPST/PEMSD8.yaml ``` -参考 [模型迁移教程](./transfer_guide.md) 迁移模型到TrafficWheel中。 - -提交更改。 - -```bash -git add . -git commit -m "Commit message" -``` - -推送到远程仓库(需要找我注册账号) - -```bash -git push origin dev -``` - -模型开发完成后,需要合并到main分支,在[这里](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls)提交pull request。 - - - -# 已知问题 - -目前,实测以下模型性能与原报告相比指标偏高:ARIMA、TCN、DCRNN - -STGCN在载入图时会有未知warning - -以下模型由于没有源码暂未实现:HA、VAR、FC-LSTM、GRU-ED - - - -# 源代码及论文 - -| 论文 | 代码 | -| ------------------------------------------------------------ | ------------------------------------------------------------ | -| [HierAttnLSTM](https://arxiv.org/abs/2201.05760v4) | [代码](https://github.com/TeRyZh/Network-Level-Travel-Prediction-Hierarchical-Attention-LSTM) | -| [DSANET](https://dl.acm.org/doi/10.1145/3357384.3358132) | [代码](https://github.com/bighuang624/DSANet) | -| [STGCN](https://arxiv.org/abs/1709.04875) | [代码](https://github.com/hazdzz/STGCN) | -| [DCRNN](https://arxiv.org/abs/1707.01926) | [代码](https://github.com/chnsh/DCRNN_PyTorch) | -| [GraphWaveNet](https://arxiv.org/pdf/1906.00121.pdf) | [代码](https://github.com/SGT-LIM/GraphWavenet) | -| [STSGCN](https://aaai.org/ojs/index.php/AAAI/article/view/5438/5294) | [代码](https://github.com/SmallNana/STSGCN_Pytorch) | -| [AGCRN](https://arxiv.org/pdf/2007.02842) | [代码](https://github.com/LeiBAI/AGCRN) | -| [STFGNN](https://arxiv.org/abs/2012.09641) | [代码](https://github.com/lwm412/STFGNN-Pytorch) | -| [STGODE](https://arxiv.org/abs/2106.12931) | [代码](https://github.com/square-coder/STGODE) | -| [STG-NCDE](https://arxiv.org/abs/2112.03558) | [代码](https://github.com/jeongwhanchoi/STG-NCDE) | -| [DDGCRN](https://www.sciencedirect.com/science/article/abs/pii/S0031320323003710) | [代码](https://github.com/wengwenchao123/DDGCRN) | -