交通轮,基线合集
Go to file
czzhangheng 9b1cf5f0ce 更新测试脚本 2025-12-03 18:11:06 +08:00
.vscode 测试所有配置是否正常运行 2025-12-01 22:29:52 +08:00
config 兼容STAEFormer 2025-12-03 17:12:46 +08:00
dataloader 兼容NYCBike 2025-11-23 21:24:10 +08:00
model 兼容STAEFormer 2025-12-03 17:12:46 +08:00
trainer 适配GraphWaveNet 2025-12-03 12:05:02 +08:00
utils trainer修改 2025-12-01 21:36:37 +08:00
.gitignore 为其他模型添加配置文件 2025-12-01 20:45:30 +08:00
LICENSE 收到垃圾堆 2025-04-22 14:54:39 +08:00
README.md 修复confuse_layer硬编码bug 2025-11-11 17:26:05 +08:00
Result.md 添加md结果 2025-04-26 15:37:40 +08:00
Result.xlsx 更新.gitignore以忽略Result.xlsx文件,修改run.py以优先使用macOS的MPS设备,优化设备设置逻辑 2025-08-18 14:32:20 +08:00
generate_launch_configs.py 测试所有配置是否正常运行 2025-12-01 22:29:52 +08:00
mypy.ini trainer修改 2025-12-01 21:36:37 +08:00
requirements.txt 更新pip, STID配置 2025-11-24 09:42:33 +08:00
run.py 兼容BeijingAirQuality。重构data,需要更新pip requirement 2025-11-20 20:19:17 +08:00
run_tests.sh 更新测试脚本 2025-12-03 18:11:06 +08:00
test_results.txt add astra-pemsbay v2 2025-12-02 09:40:24 +08:00
transfer_guide.md 解决合并冲突,整合dev和main分支的更改 2025-05-14 13:13:11 +08:00

README.md

Traffic Wheel 交通轮

依赖环境

支持python 3.10以上版本。

使用conda创建基本环境

conda create -n trafficwheel python=3.10

pip下载安装包

pip install -r requirements.txt

pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook

准备GPT预训练权重

需要海外网络,如果没有海外网络,手动下载后上传。

GPT-2文件夹内应该有两个文件{config.json, pytorch_model.bin}

mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin

跑REPST

第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。

python run.py --config ./config/REPST/PEMS-BAY.yaml