diff --git a/README.md b/README.md index 4db6252..13faf9d 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,19 @@ 使用conda创建基本环境 -``` +```bash conda create -n trafficwheel python=3.10 ``` pip下载安装包 +```bash +pip install -r requirements.txt ``` + +或 + +```bash pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook ``` @@ -20,7 +26,7 @@ pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio 参考baseline.ipynb中的命令执行,或者使用下面的命令:(请确保当前目录为TrafficWheel) -``` +```bash python run.py --model {model_name} --dataset {dataset_name} --mode {train, test} --device {cuda:0} ``` @@ -29,6 +35,10 @@ python run.py --model {model_name} --dataset {dataset_name} --mode {train, test - mode:train为训练模型,test为测试模型。测试模型需要在pre-train文件中找到模型的pth存档。 - device: 支持'cpu'、'cuda:0'、‘cuda:1’ ... 取决于机器卡数 +run.py会自动完成数据集下载、模型训练/评估工作。 + +:warning:现有的模型性能数据存放在[Result.xlsx](./Result.xlsx),不必浪费资源再跑一遍。 + # 测试模型 在实验结束后,模型的存档文件会被保存在 `experiments/dataset/训练时间 `文件夹下,共有4个文件。 @@ -39,3 +49,51 @@ python run.py --model {model_name} --dataset {dataset_name} --mode {train, test - run.log 记录了训练日志。 可以创建`pre-train/{dataset_name}`文件夹,把整个文件夹的内容拷贝到`experiments/dataset/训练时间 `文件夹下的内容全部拷贝到`pre-train/{dataset_name}`里面。之后就可以在命令中调用` --model test`进行测试。 + +:warning:注意,请及时删除experiments文件夹中不必要的文件,要不整个文件夹会越堆越大。 + +# 更改配置 + +在config文件夹中,存放了每个模型的配置文件。每个数据集单独配置,使用yaml格式。 + +你需要找到对应模型的参数进行修改。 + +配置文件分为五个部分:[data, model, train, test, log] + +- data一般不用改,存放了模型的节点数,预测窗口,历史窗口等信息 +- model中的参数要结合代码和论文看,一般会给出推荐配置。 +- train调整模型的训练细节,包括batch size,学习率,batch_norm等。 + +一般不建议对基线参数进行修改,按默认跑是最稳定的。 + +# 开发模型 + +首先你需要创建一个开发分支dev,并切换到开发分支 + +```bash +git switch -c dev +``` + +参考 [模型迁移教程](./transfer_guide.md) 迁移模型到TrafficWheel中。 + +提交更改。 + +```bash +git add . +git commit -m "Commit message" +``` + +推送到远程仓库(需要找我注册账号) + +```bash +git push origin dev +``` + +模型开发完成后,需要合并到main分支,在[这里](https://github.zhang-heng.com/czzhangheng/TrafficWheel/pulls)提交pull request。 + +# 已知问题 + +目前,实测以下模型性能与原报告相比指标偏高:ARIMA、TCN、DCRNN + +以下模型由于没有源码暂未实现:HA、VAR、FC-LSTM、GRU-ED + diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..313c036 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +pyyaml +tqdm +statsmodels +h5py +kagglehub +torch +torchvision +torchaudio +torchdiffeq +fastdtw +notebook \ No newline at end of file diff --git a/transfer_guide.md b/transfer_guide.md new file mode 100644 index 0000000..e69de29