|
|
||
|---|---|---|
| .vscode | ||
| config | ||
| dataloader | ||
| model | ||
| trainer | ||
| utils | ||
| .gitignore | ||
| LICENSE | ||
| README.md | ||
| Result.md | ||
| Result.xlsx | ||
| generate_launch_configs.py | ||
| mypy.ini | ||
| requirements.txt | ||
| run.py | ||
| test_configs.py | ||
| test_results.txt | ||
| transfer_guide.md | ||
README.md
Traffic Wheel 交通轮
依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
conda create -n trafficwheel python=3.10
pip下载安装包
pip install -r requirements.txt
或
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
准备GPT预训练权重
需要海外网络,如果没有海外网络,手动下载后上传。
GPT-2文件夹内应该有两个文件:{config.json, pytorch_model.bin}
mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
跑REPST
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。
python run.py --config ./config/REPST/PEMS-BAY.yaml