update Readme Add transfer_guide.md

This commit is contained in:
HengZhang 2025-03-03 13:42:08 +08:00
parent 5df1c810b1
commit 01782a0c81
3 changed files with 71 additions and 2 deletions

View File

@ -6,13 +6,19 @@
使用conda创建基本环境 使用conda创建基本环境
``` ```bash
conda create -n trafficwheel python=3.10 conda create -n trafficwheel python=3.10
``` ```
pip下载安装包 pip下载安装包
```bash
pip install -r requirements.txt
``` ```
```bash
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
``` ```
@ -20,7 +26,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio
参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel 参考baseline.ipynb中的命令执行或者使用下面的命令请确保当前目录为TrafficWheel
``` ```bash
python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0} python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0}
``` ```
@ -29,6 +35,10 @@ python run.py --model {model_name} --dataset {dataset_name} --mode {train, test
- modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。 - modetrain为训练模型test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。
- device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数 - device: 支持'cpu'、'cuda:0'、cuda:1 ... 取决于机器卡数
run.py会自动完成数据集下载、模型训练/评估工作。
:warning:现有的模型性能数据存放在[Result.xlsx](./Result.xlsx),不必浪费资源再跑一遍。
# 测试模型 # 测试模型
在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下共有4个文件。 在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下共有4个文件。
@ -39,3 +49,51 @@ python run.py --model {model_name} --dataset {dataset_name} --mode {train, test
- run.log 记录了训练日志。 - run.log 记录了训练日志。
可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。 可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。
:warning:注意请及时删除experiments文件夹中不必要的文件要不整个文件夹会越堆越大。
# 更改配置
在config文件夹中存放了每个模型的配置文件。每个数据集单独配置使用yaml格式。
你需要找到对应模型的参数进行修改。
配置文件分为五个部分:[data, model, train, test, log]
- data一般不用改存放了模型的节点数预测窗口历史窗口等信息
- model中的参数要结合代码和论文看一般会给出推荐配置。
- train调整模型的训练细节包括batch size学习率batch_norm等。
一般不建议对基线参数进行修改,按默认跑是最稳定的。
# 开发模型
首先你需要创建一个开发分支dev并切换到开发分支
```bash
git switch -c dev
```
参考 [模型迁移教程](./transfer_guide.md) 迁移模型到TrafficWheel中。
提交更改。
```bash
git add .
git commit -m "Commit message"
```
推送到远程仓库(需要找我注册账号)
```bash
git push origin dev
```
模型开发完成后需要合并到main分支在[这里](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls)提交pull request。
# 已知问题
目前实测以下模型性能与原报告相比指标偏高ARIMA、TCN、DCRNN
以下模型由于没有源码暂未实现HA、VAR、FC-LSTM、GRU-ED

11
requirements.txt Normal file
View File

@ -0,0 +1,11 @@
pyyaml
tqdm
statsmodels
h5py
kagglehub
torch
torchvision
torchaudio
torchdiffeq
fastdtw
notebook

0
transfer_guide.md Normal file
View File