45 lines
954 B
Markdown
Executable File
45 lines
954 B
Markdown
Executable File
# Traffic Wheel 交通轮
|
||
|
||
# 依赖环境
|
||
|
||
支持python 3.10以上版本。
|
||
|
||
使用conda创建基本环境
|
||
|
||
```bash
|
||
conda create -n trafficwheel python=3.10
|
||
```
|
||
|
||
pip下载安装包
|
||
|
||
```bash
|
||
pip install -r requirements.txt
|
||
```
|
||
|
||
或
|
||
|
||
```bash
|
||
pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook
|
||
```
|
||
|
||
# 准备GPT预训练权重
|
||
|
||
需要海外网络,如果没有海外网络,手动下载后上传。
|
||
|
||
GPT-2文件夹内应该有两个文件:`{config.json, pytorch_model.bin}`
|
||
|
||
```bash
|
||
mkdir GPT-2
|
||
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
|
||
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
|
||
```
|
||
|
||
|
||
# 跑REPST
|
||
第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。
|
||
|
||
```bash
|
||
python run.py --config ./config/REPST/PEMS-BAY.yaml
|
||
```
|
||
|