TrafficWheel/README.md

45 lines
954 B
Markdown
Executable File
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Traffic Wheel 交通轮
# 依赖环境
支持python 3.10以上版本。
使用conda创建基本环境
```bash
conda create -n trafficwheel python=3.10
```
pip下载安装包
```bash
pip install -r requirements.txt
```
```bash
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
```
# 准备GPT预训练权重
需要海外网络,如果没有海外网络,手动下载后上传。
GPT-2文件夹内应该有两个文件`{config.json, pytorch_model.bin}`
```bash
mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
```
# 跑REPST
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。
```bash
python run.py --config ./config/REPST/PEMS-BAY.yaml
```