优化READMER,注册STGODE-LLM

This commit is contained in:
harry.zhang 2025-09-04 06:05:09 +00:00
parent 8ee9b4cfb2
commit 70da9574da
7 changed files with 73 additions and 29 deletions

View File

@ -2,7 +2,7 @@
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="jdk" jdkName="I" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">

View File

@ -3,5 +3,5 @@
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
<component name="ProjectRootManager" version="2" project-jdk-name="I" project-jdk-type="Python SDK" />
</project>

View File

@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

View File

@ -5,10 +5,10 @@
</component>
<component name="ChangeListManager">
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/Project-I.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/Project-I.iml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
@ -38,7 +38,7 @@
</excluded-from-favorite>
<option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="main" />
<entry key="$PROJECT_DIR$" value="STGODE" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
@ -58,10 +58,13 @@
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"Python.STDEN.executor": "Debug",
"Python.STGODE (1).executor": "Run",
"Python.STGODE.executor": "Run",
"Python.main.executor": "Run",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true",
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
"RunOnceActivity.git.unshallow": "true",
"git-widget-placeholder": "main",
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
"node.js.detected.package.eslint": "true",
@ -69,6 +72,7 @@
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"settings.editor.selected.configurable": "project.propVCSSupport.DirectoryMappings",
"vue.rearranger.settings.migration": "true"
}
}]]></component>
@ -77,10 +81,11 @@
<window_info id="Space Code Reviews" show_stripe_button="false" />
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
<window_info id="Merge Requests" show_stripe_button="false" />
<window_info id="Backup and Sync History" />
<window_info id="Commit_Guest" show_stripe_button="false" />
<window_info id="Pull Requests" show_stripe_button="false" />
<window_info id="Learn" show_stripe_button="false" />
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27321428" />
<window_info id="Commit" order="1" weight="0.25" />
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
@ -88,17 +93,20 @@
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
<window_info anchor="bottom" id="HfCacheToolWindow" />
<window_info anchor="bottom" id="Find" />
<window_info anchor="bottom" id="Version Control" order="0" />
<window_info anchor="bottom" id="Problems" order="1" />
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
<window_info active="true" anchor="bottom" id="Services" order="4" visible="true" weight="0.32989067" />
<window_info anchor="bottom" id="Terminal" order="3" weight="0.3296646" />
<window_info anchor="bottom" id="Services" order="4" />
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
<window_info active="true" anchor="bottom" id="Run" order="7" visible="true" weight="0.3296646" />
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
<window_info anchor="right" id="Coverage" side_tool="true" />
<window_info anchor="right" id="SciView" show_stripe_button="false" />
<window_info anchor="right" id="Jupyter Variables" side_tool="true" />
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
<window_info anchor="right" id="Database" order="2" weight="0.25" />
@ -117,7 +125,7 @@
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
</component>
<component name="RunManager" selected="Python.STGODE">
<component name="RunManager" selected="Python.STGODE (1)">
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
@ -142,6 +150,29 @@
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="STGODE (1)" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="true" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STGODE_LLM_GPT2/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
@ -169,12 +200,14 @@
<list>
<item itemvalue="Python.STDEN" />
<item itemvalue="Python.STGODE" />
<item itemvalue="Python.STGODE (1)" />
</list>
</component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
<option value="bundled-js-predefined-d6986cc7102b-e03c56caf84a-JavaScript-PY-252.23892.515" />
<option value="bundled-python-sdk-7e47963ff851-f0eec537fc84-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.23892.515" />
</set>
</attachedChunks>
</component>
@ -189,6 +222,8 @@
<workItem from="1756727623101" duration="4721000" />
<workItem from="1756856673845" duration="652000" />
<workItem from="1756864144998" duration="1063000" />
<workItem from="1756960597140" duration="1062000" />
<workItem from="1756965151878" duration="505000" />
</task>
<servers />
</component>
@ -207,8 +242,9 @@
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE_LLM.coverage" NAME="STGODE-LLM 覆盖结果" MODIFIED="1756950739801" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756885209907" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE__1_.coverage" NAME="STGODE (1) 覆盖结果" MODIFIED="1756965216400" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
</component>
</project>

View File

@ -2,23 +2,34 @@
Secret Projct
```
mkdir -p models/gpt2
```
## Prepare Env.
```
pip install -r requirement.txt
```
## Download dataset
```
python utils/download.py
```
## Download gpt weight
mkdir -p models/gpt2
`mkdir -p models/gpt2`
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
```bash
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./models/gpt2/config.json
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./models/gpt2/config.json
```
Use pytorch >= 2.6 to load model.
## Run
Run: `python.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
```
`python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
```

View File

@ -66,11 +66,11 @@ class GPT2BackboneHF(nn.Module):
super().__init__()
from transformers import GPT2Model
if local_dir is not None and len(local_dir) > 0:
self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True)
self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False)
else:
if model_name is None:
model_name = 'gpt2'
self.model = GPT2Model.from_pretrained(model_name)
self.model = GPT2Model.from_pretrained(model_name, use_cache=False)
if gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.hidden_size = self.model.config.hidden_size

View File

@ -1,12 +1,16 @@
from models.STDEN.stden_model import STDENModel
from models.STGODE.STGODE import ODEGCN
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2
def model_selector(config):
model_name = config['basic']['model']
model = None
match model_name:
case 'STDEN':
case 'STDEN':
model = STDENModel(config)
case 'STGODE':
model = ODEGCN(config)
return model
case 'STGODE-LLM-GPT2':
model = ODEGCN_LLM_GPT2(config)
return model