943 B
Executable File
943 B
Executable File
Traffic Wheel 交通轮
依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
conda create -n trafficwheel python=3.10
pip下载安装包
pip install -r requirements.txt
或
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
准备GPT预训练权重
需要海外网络,如果没有海外网络,手动下载后上传。
GPT-2文件夹内应该有两个文件:{config.json, pytorch_model.bin}
mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
跑REPST
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8。
python run.py --config ./config/REPST/PEMSD8.yaml