交通轮,基线合集
Go to file
czzhangheng 5a7ec07a12 Merge pull request 'REPST' (#3) from REPST into main
Reviewed-on: #3
2025-12-20 16:03:22 +08:00
.vscode impl mtgnn 2025-12-10 23:31:17 +08:00
config impl STNorm 2025-12-20 15:45:13 +08:00
dataloader refactor(dataloader): 重构数据加载器代码,优化滑动窗口生成和归一化处理 2025-12-15 20:54:20 +08:00
model impl STNorm 2025-12-20 15:45:13 +08:00
trainer impl FPT 2025-12-19 10:19:17 +08:00
utils refactor: 重构数据加载器和训练器代码,优化代码结构和可读性 2025-12-15 01:38:47 +08:00
.gitignore ignore 2025-12-01 20:48:46 +08:00
LICENSE 收到垃圾堆 2025-04-22 14:54:39 +08:00
README.md 修复confuse_layer硬编码bug 2025-11-11 17:26:05 +08:00
Result.md 添加md结果 2025-04-26 15:37:40 +08:00
Result.xlsx 更新.gitignore以忽略Result.xlsx文件,修改run.py以优先使用macOS的MPS设备,优化设备设置逻辑 2025-08-18 14:32:20 +08:00
requirements.txt 实现iTransformer 2025-12-09 16:11:49 +08:00
run.py 实现iTransformer 2025-12-09 16:11:49 +08:00
train.py impl STNorm 2025-12-20 15:45:13 +08:00
transfer_guide.md 解决合并冲突,整合dev和main分支的更改 2025-05-14 13:13:11 +08:00

README.md

Traffic Wheel 交通轮

依赖环境

支持python 3.10以上版本。

使用conda创建基本环境

conda create -n trafficwheel python=3.10

pip下载安装包

pip install -r requirements.txt

pip install pyyaml tqdm statsmodels h5py kagglehub torch torchvision torchaudio torchdiffeq fastdtw notebook

准备GPT预训练权重

需要海外网络,如果没有海外网络,手动下载后上传。

GPT-2文件夹内应该有两个文件{config.json, pytorch_model.bin}

mkdir GPT-2
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./GPT-2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./GPT-2/pytorch_model.bin

跑REPST

第一遍跑时程序会自动下载数据集。目前仅支持PEMSD8/PEMS-BAY。

python run.py --config ./config/REPST/PEMS-BAY.yaml