问题分析: 1. 参数量异常小(16,522) - 缺少node到edge转换层 2. 维度错误 - 编码器期望edge格式但收到node格式输入 3. 解码器维度计算错误 修复内容: - 添加node_to_edge和edge_to_node转换层,参数量从16,522增加到1,009,002 - 修改forward方法正确处理node格式输入输出 - 修复编码器以处理edge格式的中间数据 - 修正解码器中的维度计算问题 测试结果: - 参数量:1,009,002 (合理范围) - 输入输出形状正确:(batch_size, seq_len/horizon, num_nodes, input/output_dim) - 模型可以正常前向传播 |
||
|---|---|---|
| configs | ||
| data | ||
| models | ||
| trainer | ||
| utils | ||
| .gitignore | ||
| LICENSE | ||
| README.md | ||
| main.py | ||
| requirements.txt | ||
README.md
Project-I
Secret Projct
mkdir -p models/gpt2
Prepare Env.
pip install -r requirement.txt
Download dataset
python utils/download.py
Download gpt weight
mkdir -p models/gpt2
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./models/gpt2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./models/gpt2/config.json
Use pytorch >= 2.6 to load model.
Run
python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml