|
|
||
|---|---|---|
| data_provider | ||
| figures | ||
| models | ||
| scripts | ||
| utils | ||
| .gitignore | ||
| README.md | ||
| prepare_pems_bay.py | ||
| requirements.txt | ||
| run.py | ||
README.md
RePST 修复版
准备GPT-2预训练权重
mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin
准备PEMS-BAY数据集,按照BasicTS方法准备
Google Drive 可使用gdown下载。
解压后,确保 ./datasets/PEMS-BAY 文件夹内具有 adj_mx.pkl, data.dat, desc,json文件, 然后运行脚本
python prepare_pems_bay.py
在PEMS-BAY数据集文件夹下生成 train.npz, val.npz, test.npz
根据BasicTS仓库配置BasicTS环境,亦或是使用
pip install -r requirements.txt
我是直接使用现有的BasicTS环境,因此没有做过测试
创建log文件夹到项目仓库
mkdir log
开跑
python run.py --root_path datasets --data_path PEMS-BAY --device cuda:0 --seq_len 12 --pred_len 12