- 在STDENModel中添加node_to_edge和edge_to_node转换层 - 修改forward方法以处理node_num输入并输出node_num格式 - 更新编码器以处理edge格式的中间数据 - 修复解码器中的维度计算问题 - 解决设备不匹配和数据类型不一致问题 - 更新.gitignore以允许models/STDEN/代码目录被跟踪 现在模型可以接受node_num格式的输入,内部转换为edge_num进行处理,最后转换回node_num输出。 |
||
|---|---|---|
| configs | ||
| data | ||
| models | ||
| trainer | ||
| utils | ||
| .gitignore | ||
| LICENSE | ||
| README.md | ||
| main.py | ||
| requirements.txt | ||
README.md
Project-I
Secret Projct
mkdir -p models/gpt2
Prepare Env.
pip install -r requirement.txt
Download dataset
python utils/download.py
Download gpt weight
mkdir -p models/gpt2
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./models/gpt2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./models/gpt2/config.json
Use pytorch >= 2.6 to load model.
Run
python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml