TrafficWheel/README.md

954 B
Executable File
Raw Blame History

Traffic Wheel 交通轮

依赖环境

支持python 3.10以上版本。

使用conda创建基本环境

conda create -n trafficwheel python=3.10

pip下载安装包

pip install -r requirements.txt

pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook

准备GPT预训练权重

需要海外网络,如果没有海外网络,手动下载后上传。

GPT-2文件夹内应该有两个文件{config.json, pytorch_model.bin}

mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin

跑REPST

第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。

python run.py --config ./config/REPST/PEMS-BAY.yaml