Compare commits
3 Commits
0bf678f290
...
70da9574da
| Author | SHA1 | Date |
|---|---|---|
|
|
70da9574da | |
|
|
8ee9b4cfb2 | |
|
|
d62319302f |
|
|
@ -165,4 +165,6 @@ cython_debug/
|
|||
exp/
|
||||
STDEN/
|
||||
models/gpt2/
|
||||
pre-trained/
|
||||
pre-trained/
|
||||
data/*
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@
|
|||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||
<orderEntry type="jdk" jdkName="I" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
|
|
|
|||
|
|
@ -3,5 +3,5 @@
|
|||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="I" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -5,10 +5,10 @@
|
|||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
|
||||
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/Project-I.iml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/Project-I.iml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/misc.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/misc.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
|
|
@ -38,7 +38,7 @@
|
|||
</excluded-from-favorite>
|
||||
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
||||
<map>
|
||||
<entry key="$PROJECT_DIR$" value="main" />
|
||||
<entry key="$PROJECT_DIR$" value="STGODE" />
|
||||
</map>
|
||||
</option>
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
|
|
@ -58,10 +58,13 @@
|
|||
<component name="PropertiesComponent"><![CDATA[{
|
||||
"keyToString": {
|
||||
"Python.STDEN.executor": "Debug",
|
||||
"Python.STGODE (1).executor": "Run",
|
||||
"Python.STGODE.executor": "Run",
|
||||
"Python.main.executor": "Run",
|
||||
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
|
||||
"RunOnceActivity.git.unshallow": "true",
|
||||
"git-widget-placeholder": "main",
|
||||
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
|
||||
"node.js.detected.package.eslint": "true",
|
||||
|
|
@ -69,6 +72,7 @@
|
|||
"node.js.selected.package.eslint": "(autodetect)",
|
||||
"node.js.selected.package.tslint": "(autodetect)",
|
||||
"nodejs_package_manager_path": "npm",
|
||||
"settings.editor.selected.configurable": "project.propVCSSupport.DirectoryMappings",
|
||||
"vue.rearranger.settings.migration": "true"
|
||||
}
|
||||
}]]></component>
|
||||
|
|
@ -77,10 +81,11 @@
|
|||
<window_info id="Space Code Reviews" show_stripe_button="false" />
|
||||
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
|
||||
<window_info id="Merge Requests" show_stripe_button="false" />
|
||||
<window_info id="Backup and Sync History" />
|
||||
<window_info id="Commit_Guest" show_stripe_button="false" />
|
||||
<window_info id="Pull Requests" show_stripe_button="false" />
|
||||
<window_info id="Learn" show_stripe_button="false" />
|
||||
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
|
||||
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27321428" />
|
||||
<window_info id="Commit" order="1" weight="0.25" />
|
||||
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
||||
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
|
||||
|
|
@ -88,17 +93,20 @@
|
|||
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
|
||||
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
|
||||
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="HfCacheToolWindow" />
|
||||
<window_info anchor="bottom" id="Find" />
|
||||
<window_info anchor="bottom" id="Version Control" order="0" />
|
||||
<window_info anchor="bottom" id="Problems" order="1" />
|
||||
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
|
||||
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
|
||||
<window_info active="true" anchor="bottom" id="Services" order="4" visible="true" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="Terminal" order="3" weight="0.3296646" />
|
||||
<window_info anchor="bottom" id="Services" order="4" />
|
||||
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
||||
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
|
||||
<window_info active="true" anchor="bottom" id="Run" order="7" visible="true" weight="0.3296646" />
|
||||
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
|
||||
<window_info anchor="right" id="Coverage" side_tool="true" />
|
||||
<window_info anchor="right" id="SciView" show_stripe_button="false" />
|
||||
<window_info anchor="right" id="Jupyter Variables" side_tool="true" />
|
||||
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
|
||||
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
|
||||
<window_info anchor="right" id="Database" order="2" weight="0.25" />
|
||||
|
|
@ -117,7 +125,7 @@
|
|||
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||
</key>
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.STGODE">
|
||||
<component name="RunManager" selected="Python.STGODE (1)">
|
||||
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="Project-I" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
|
|
@ -142,6 +150,29 @@
|
|||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="STGODE (1)" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="Project-I" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="WORKING_DIRECTORY" value="" />
|
||||
<option name="IS_MODULE_SDK" value="true" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="--config ./configs/STGODE_LLM_GPT2/PEMS08.yaml" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="Project-I" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
|
|
@ -169,12 +200,14 @@
|
|||
<list>
|
||||
<item itemvalue="Python.STDEN" />
|
||||
<item itemvalue="Python.STGODE" />
|
||||
<item itemvalue="Python.STGODE (1)" />
|
||||
</list>
|
||||
</component>
|
||||
<component name="SharedIndexes">
|
||||
<attachedChunks>
|
||||
<set>
|
||||
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
|
||||
<option value="bundled-js-predefined-d6986cc7102b-e03c56caf84a-JavaScript-PY-252.23892.515" />
|
||||
<option value="bundled-python-sdk-7e47963ff851-f0eec537fc84-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.23892.515" />
|
||||
</set>
|
||||
</attachedChunks>
|
||||
</component>
|
||||
|
|
@ -189,6 +222,8 @@
|
|||
<workItem from="1756727623101" duration="4721000" />
|
||||
<workItem from="1756856673845" duration="652000" />
|
||||
<workItem from="1756864144998" duration="1063000" />
|
||||
<workItem from="1756960597140" duration="1062000" />
|
||||
<workItem from="1756965151878" duration="505000" />
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
|
|
@ -207,8 +242,9 @@
|
|||
</breakpoint-manager>
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
<SUITE FILE_PATH="coverage/Project_I$STGODE_LLM.coverage" NAME="STGODE-LLM 覆盖结果" MODIFIED="1756950739801" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756885209907" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
<SUITE FILE_PATH="coverage/Project_I$STGODE__1_.coverage" NAME="STGODE (1) 覆盖结果" MODIFIED="1756965216400" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
</component>
|
||||
</project>
|
||||
19
README.md
19
README.md
|
|
@ -2,19 +2,34 @@
|
|||
|
||||
Secret Projct
|
||||
|
||||
```
|
||||
mkdir -p models/gpt2
|
||||
```
|
||||
|
||||
## Prepare Env.
|
||||
```
|
||||
pip install -r requirement.txt
|
||||
```
|
||||
|
||||
## Download dataset
|
||||
```
|
||||
python utils/download.py
|
||||
```
|
||||
|
||||
## Download gpt weight
|
||||
|
||||
mkdir -p models/gpt2
|
||||
`mkdir -p models/gpt2`
|
||||
|
||||
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
|
||||
|
||||
```bash
|
||||
wget https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true -O ./models/gpt2/config.json
|
||||
wget https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true -O ./models/gpt2/config.json
|
||||
```
|
||||
Use pytorch >= 2.6 to load model.
|
||||
|
||||
## Run
|
||||
|
||||
Run: `python.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
|
||||
```
|
||||
`python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
|
||||
```
|
||||
|
|
@ -1,296 +0,0 @@
|
|||
from,to,cost
|
||||
9,153,310.6
|
||||
153,62,330.9
|
||||
62,111,332.9
|
||||
111,11,324.2
|
||||
11,28,336.0
|
||||
28,169,133.7
|
||||
138,135,354.7
|
||||
135,133,387.9
|
||||
133,163,337.1
|
||||
163,20,352.0
|
||||
20,19,420.8
|
||||
19,14,351.3
|
||||
14,39,340.2
|
||||
39,164,350.3
|
||||
164,167,365.2
|
||||
167,70,359.0
|
||||
70,59,388.2
|
||||
59,58,305.7
|
||||
58,67,294.4
|
||||
67,66,299.5
|
||||
66,55,313.3
|
||||
55,53,332.1
|
||||
53,150,278.9
|
||||
150,61,308.4
|
||||
61,64,311.4
|
||||
64,63,243.6
|
||||
47,65,372.8
|
||||
65,48,319.4
|
||||
48,49,309.7
|
||||
49,54,320.5
|
||||
54,56,318.3
|
||||
56,57,297.9
|
||||
57,68,293.5
|
||||
68,69,342.5
|
||||
69,60,318.0
|
||||
60,17,305.9
|
||||
17,5,321.4
|
||||
5,18,402.2
|
||||
18,22,447.4
|
||||
22,30,377.5
|
||||
30,29,417.7
|
||||
29,21,360.8
|
||||
21,132,407.6
|
||||
132,134,386.9
|
||||
134,136,350.2
|
||||
123,121,326.3
|
||||
121,140,385.2
|
||||
140,118,393.0
|
||||
118,96,296.7
|
||||
96,94,398.2
|
||||
94,86,337.1
|
||||
86,78,473.8
|
||||
78,46,353.4
|
||||
46,152,385.7
|
||||
152,157,350.0
|
||||
157,35,354.4
|
||||
35,77,356.1
|
||||
77,52,354.2
|
||||
52,3,357.8
|
||||
3,16,382.4
|
||||
16,0,55.7
|
||||
42,12,335.1
|
||||
12,139,328.8
|
||||
139,168,412.6
|
||||
168,154,337.3
|
||||
154,143,370.7
|
||||
143,10,6.3
|
||||
107,105,354.6
|
||||
105,104,386.9
|
||||
104,148,362.1
|
||||
148,97,316.3
|
||||
97,101,380.7
|
||||
101,137,361.4
|
||||
137,102,365.5
|
||||
102,24,375.5
|
||||
24,166,312.2
|
||||
129,156,256.1
|
||||
156,33,329.1
|
||||
33,32,356.5
|
||||
91,89,405.6
|
||||
89,147,347.0
|
||||
147,15,351.7
|
||||
15,44,339.5
|
||||
44,41,350.8
|
||||
41,43,322.6
|
||||
43,100,338.9
|
||||
100,83,347.9
|
||||
83,87,327.2
|
||||
87,88,321.0
|
||||
88,75,335.8
|
||||
75,51,384.8
|
||||
51,73,391.1
|
||||
73,71,289.3
|
||||
31,155,260.0
|
||||
155,34,320.4
|
||||
34,128,393.3
|
||||
145,115,399.4
|
||||
115,112,328.1
|
||||
112,8,469.4
|
||||
8,117,816.2
|
||||
117,125,397.1
|
||||
125,127,372.7
|
||||
127,109,380.5
|
||||
109,161,355.5
|
||||
161,110,367.7
|
||||
110,160,102.0
|
||||
72,159,342.9
|
||||
159,50,383.3
|
||||
50,74,354.1
|
||||
74,82,350.2
|
||||
82,81,335.4
|
||||
81,99,391.6
|
||||
99,84,354.9
|
||||
84,13,306.4
|
||||
13,40,327.4
|
||||
40,162,413.9
|
||||
162,108,301.9
|
||||
108,146,317.8
|
||||
146,85,376.6
|
||||
85,90,347.0
|
||||
26,27,341.6
|
||||
27,6,359.4
|
||||
6,149,417.8
|
||||
149,126,388.0
|
||||
126,124,384.3
|
||||
124,7,763.3
|
||||
7,114,323.1
|
||||
114,113,351.6
|
||||
113,116,411.9
|
||||
116,144,262.0
|
||||
25,103,350.2
|
||||
103,23,376.3
|
||||
23,165,396.4
|
||||
165,38,381.0
|
||||
38,92,368.0
|
||||
92,37,336.3
|
||||
37,130,357.8
|
||||
130,106,532.3
|
||||
106,131,166.5
|
||||
1,2,371.6
|
||||
2,4,338.1
|
||||
4,76,429.0
|
||||
76,36,366.1
|
||||
36,158,344.5
|
||||
158,151,350.1
|
||||
151,45,358.8
|
||||
45,93,340.9
|
||||
93,80,329.9
|
||||
80,79,384.1
|
||||
79,95,335.7
|
||||
95,98,320.9
|
||||
98,119,340.3
|
||||
119,120,376.8
|
||||
120,122,393.1
|
||||
122,141,428.7
|
||||
141,142,359.3
|
||||
30,165,379.6
|
||||
165,29,41.7
|
||||
29,38,343.3
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
17,153,375.6
|
||||
153,5,256.3
|
||||
153,62,330.9
|
||||
18,6,499.4
|
||||
6,22,254.0
|
||||
22,149,185.4
|
||||
22,4,257.9
|
||||
4,30,236.8
|
||||
30,76,307.0
|
||||
95,98,320.9
|
||||
98,144,45.1
|
||||
45,93,340.9
|
||||
93,106,112.2
|
||||
162,151,113.6
|
||||
151,108,192.9
|
||||
108,45,359.8
|
||||
146,92,311.2
|
||||
92,85,343.9
|
||||
85,37,373.2
|
||||
13,169,326.2
|
||||
169,40,96.1
|
||||
124,13,460.7
|
||||
13,7,305.5
|
||||
7,40,624.1
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
90,132,152.2
|
||||
26,32,106.7
|
||||
9,129,148.3
|
||||
129,153,219.6
|
||||
31,26,116.0
|
||||
26,155,270.7
|
||||
9,128,142.2
|
||||
128,153,215.0
|
||||
153,167,269.7
|
||||
167,62,64.8
|
||||
62,70,332.6
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
44,169,397.8
|
||||
169,41,124.0
|
||||
44,124,375.7
|
||||
124,41,243.9
|
||||
41,7,519.4
|
||||
6,14,289.3
|
||||
14,149,259.0
|
||||
149,39,206.9
|
||||
144,98,45.1
|
||||
19,4,326.8
|
||||
4,14,178.6
|
||||
14,76,299.0
|
||||
15,151,136.4
|
||||
151,44,203.1
|
||||
45,106,260.6
|
||||
106,93,112.2
|
||||
20,165,132.5
|
||||
165,19,289.2
|
||||
89,92,323.2
|
||||
92,147,321.9
|
||||
147,37,48.2
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
104,147,277.5
|
||||
147,148,84.7
|
||||
20,101,201.2
|
||||
101,19,534.4
|
||||
19,137,245.5
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
44,42,342.3
|
||||
42,41,102.5
|
||||
44,8,789.1
|
||||
8,41,657.4
|
||||
41,117,160.5
|
||||
168,167,172.4
|
||||
167,154,165.2
|
||||
143,128,81.9
|
||||
128,10,88.2
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
15,152,135.0
|
||||
152,44,204.6
|
||||
19,77,320.7
|
||||
77,14,299.8
|
||||
14,52,127.6
|
||||
14,127,314.8
|
||||
127,39,280.4
|
||||
39,109,237.0
|
||||
31,160,116.5
|
||||
160,155,272.4
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
32,160,107.7
|
||||
72,162,3274.4
|
||||
162,13,554.5
|
||||
162,40,413.9
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
13,42,319.8
|
||||
42,40,40.7
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
8,13,450.3
|
||||
13,117,378.5
|
||||
117,40,64.0
|
||||
46,162,391.6
|
||||
162,152,115.3
|
||||
152,108,191.4
|
||||
104,108,375.9
|
||||
108,148,311.6
|
||||
148,146,80.0
|
||||
21,90,396.9
|
||||
90,132,152.2
|
||||
101,29,252.3
|
||||
29,137,110.7
|
||||
77,22,353.8
|
||||
22,52,227.8
|
||||
52,30,186.6
|
||||
127,18,425.2
|
||||
18,109,439.1
|
||||
109,22,135.5
|
||||
168,17,232.7
|
||||
17,154,294.2
|
||||
154,5,166.3
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
|
Binary file not shown.
|
|
@ -0,0 +1,220 @@
|
|||
import csv
|
||||
import os
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from scipy.sparse import coo_matrix
|
||||
from scipy.sparse.linalg import norm
|
||||
import scipy.sparse as sp
|
||||
import torch
|
||||
|
||||
def get_adj(args):
|
||||
dataset_path = './data'
|
||||
match args['num_nodes']:
|
||||
case 358:
|
||||
dataset_name = 'PEMS03'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
|
||||
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
|
||||
case 307:
|
||||
dataset_name = 'PEMS04'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||
case 883:
|
||||
dataset_name = 'PEMS07'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path)
|
||||
case 170:
|
||||
dataset_name = 'PEMS08'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||
|
||||
return adj
|
||||
|
||||
def get_gso(args):
|
||||
dataset_path = './data'
|
||||
match args['num_nodes']:
|
||||
case 358:
|
||||
dataset_name = 'PEMS03'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
|
||||
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
|
||||
case 307:
|
||||
dataset_name = 'PEMS04'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||
case 883:
|
||||
dataset_name = 'PEMS07'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path)
|
||||
case 170:
|
||||
dataset_name = 'PEMS08'
|
||||
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
|
||||
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
|
||||
|
||||
gso = calc_gso(adj, args['gso_type'])
|
||||
if args['graph_conv_type'] == 'cheb_graph_conv':
|
||||
gso = calc_chebynet_gso(gso)
|
||||
gso = gso.toarray()
|
||||
gso = gso.astype(dtype=np.float32)
|
||||
gso = torch.from_numpy(gso).to(args['device'])
|
||||
return gso
|
||||
|
||||
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
|
||||
'''
|
||||
Parameters
|
||||
----------
|
||||
adj_path: str, path of the csv file contains edges information
|
||||
num_nodes: int, the number of vertices
|
||||
id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
|
||||
std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization
|
||||
|
||||
Returns
|
||||
----------
|
||||
A: np.ndarray, adjacency matrix
|
||||
distanceA: np.ndarray, distance matrix (normalized if std=True)
|
||||
'''
|
||||
if 'npy' in adj_path:
|
||||
adj_mx = np.load(adj_path)
|
||||
return adj_mx, None
|
||||
|
||||
else:
|
||||
A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
|
||||
distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
|
||||
|
||||
# 如果提供了id_filename,说明节点ID不是从0开始的,需要重新映射
|
||||
if id_filename:
|
||||
with open(id_filename, 'r') as f:
|
||||
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
|
||||
|
||||
with open(adj_path, 'r') as f:
|
||||
f.readline() # 略过表头那一行
|
||||
reader = csv.reader(f)
|
||||
costs = [] # 用于收集所有cost值
|
||||
for row in reader:
|
||||
if len(row) != 3:
|
||||
continue
|
||||
i, j, distance = int(row[0]), int(row[1]), float(row[2])
|
||||
A[id_dict[i], id_dict[j]] = 1
|
||||
# 确保距离值为正
|
||||
distance = max(distance, 1e-6)
|
||||
costs.append(distance) # 收集cost值
|
||||
distanceA[id_dict[i], id_dict[j]] = distance
|
||||
|
||||
else: # 如果没有提供id_filename,说明节点ID是从0开始的
|
||||
with open(adj_path, 'r') as f:
|
||||
f.readline() # 略过表头那一行
|
||||
reader = csv.reader(f)
|
||||
costs = [] # 用于收集所有cost值
|
||||
for row in reader:
|
||||
if len(row) != 3:
|
||||
continue
|
||||
i, j, distance = int(row[0]), int(row[1]), float(row[2])
|
||||
A[i, j] = 1
|
||||
# 确保距离值为正
|
||||
distance = max(distance, 1e-6)
|
||||
costs.append(distance) # 收集cost值
|
||||
distanceA[i, j] = distance
|
||||
|
||||
# 如果std=True,对CSV中的所有cost值进行高斯正态分布标准化
|
||||
if std:
|
||||
mean_cost = np.mean(costs) # 计算cost值的均值
|
||||
std_cost = np.std(costs) # 计算cost值的标准差
|
||||
for idx in np.ndindex(distanceA.shape): # 遍历矩阵
|
||||
if distanceA[idx] > 0: # 只对非零元素进行标准化
|
||||
normalized_value = (distanceA[idx] - mean_cost) / std_cost
|
||||
# 确保标准化后的值为正
|
||||
normalized_value = max(normalized_value, 1e-6)
|
||||
distanceA[idx] = normalized_value
|
||||
|
||||
# 确保矩阵中没有零行
|
||||
row_sums = distanceA.sum(axis=1)
|
||||
zero_rows = np.where(row_sums == 0)[0]
|
||||
for row in zero_rows:
|
||||
distanceA[row, :] = 1e-6 # 将零行替换为一个非零的默认值
|
||||
|
||||
return A, distanceA
|
||||
|
||||
|
||||
def calc_gso(dir_adj, gso_type):
|
||||
n_vertex = dir_adj.shape[0]
|
||||
|
||||
if not sp.issparse(dir_adj):
|
||||
dir_adj = sp.csc_matrix(dir_adj)
|
||||
elif dir_adj.format != 'csc':
|
||||
dir_adj = dir_adj.tocsc()
|
||||
|
||||
id = sp.identity(n_vertex, format='csc')
|
||||
|
||||
# Symmetrizing an adjacency matrix
|
||||
adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
|
||||
# adj = 0.5 * (dir_adj + dir_adj.transpose())
|
||||
|
||||
if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
|
||||
adj = adj + id
|
||||
|
||||
if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
|
||||
row_sum = adj.sum(axis=1).A1
|
||||
# Check for zero or negative values in row_sum
|
||||
if np.any(row_sum <= 0):
|
||||
raise ValueError(
|
||||
"Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
|
||||
|
||||
row_sum_inv_sqrt = np.power(row_sum, -0.5)
|
||||
row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0. # Handle inf values
|
||||
deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
|
||||
# A_{sym} = D^{-0.5} * A * D^{-0.5}
|
||||
sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
|
||||
|
||||
if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
|
||||
sym_norm_lap = id - sym_norm_adj
|
||||
gso = sym_norm_lap
|
||||
else:
|
||||
gso = sym_norm_adj
|
||||
|
||||
elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
|
||||
row_sum = np.sum(adj, axis=1).A1
|
||||
# Check for zero or negative values in row_sum
|
||||
if np.any(row_sum <= 0):
|
||||
raise ValueError(
|
||||
"Row sum contains zero or negative values, which is not allowed for random walk normalization.")
|
||||
|
||||
row_sum_inv = np.power(row_sum, -1)
|
||||
row_sum_inv[np.isinf(row_sum_inv)] = 0. # Handle inf values
|
||||
deg_inv = sp.diags(row_sum_inv, format='csc')
|
||||
# A_{rw} = D^{-1} * A
|
||||
rw_norm_adj = deg_inv.dot(adj)
|
||||
|
||||
if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
|
||||
rw_norm_lap = id - rw_norm_adj
|
||||
gso = rw_norm_lap
|
||||
else:
|
||||
gso = rw_norm_adj
|
||||
|
||||
else:
|
||||
raise ValueError(f'{gso_type} is not defined.')
|
||||
|
||||
# Check for nan or inf in the final result
|
||||
if np.isnan(gso.data).any() or np.isinf(gso.data).any():
|
||||
raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
|
||||
|
||||
return gso
|
||||
|
||||
|
||||
def calc_chebynet_gso(gso):
|
||||
if sp.issparse(gso) == False:
|
||||
gso = sp.csc_matrix(gso)
|
||||
elif gso.format != 'csc':
|
||||
gso = gso.tocsc()
|
||||
|
||||
id = sp.identity(gso.shape[0], format='csc')
|
||||
# If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
|
||||
eigval_max = norm(gso, 2)
|
||||
|
||||
# If the gso is symmetric or random walk normalized Laplacian,
|
||||
# then the maximum eigenvalue is smaller than or equals to 2.
|
||||
if eigval_max >= 2:
|
||||
gso = gso - id
|
||||
else:
|
||||
gso = 2 * gso / eigval_max - id
|
||||
|
||||
return gso
|
||||
|
|
@ -66,11 +66,11 @@ class GPT2BackboneHF(nn.Module):
|
|||
super().__init__()
|
||||
from transformers import GPT2Model
|
||||
if local_dir is not None and len(local_dir) > 0:
|
||||
self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True)
|
||||
self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False)
|
||||
else:
|
||||
if model_name is None:
|
||||
model_name = 'gpt2'
|
||||
self.model = GPT2Model.from_pretrained(model_name)
|
||||
self.model = GPT2Model.from_pretrained(model_name, use_cache=False)
|
||||
if gradient_checkpointing:
|
||||
self.model.gradient_checkpointing_enable()
|
||||
self.hidden_size = self.model.config.hidden_size
|
||||
|
|
|
|||
|
|
@ -1,12 +1,16 @@
|
|||
from models.STDEN.stden_model import STDENModel
|
||||
from models.STGODE.STGODE import ODEGCN
|
||||
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2
|
||||
|
||||
|
||||
def model_selector(config):
|
||||
model_name = config['basic']['model']
|
||||
model = None
|
||||
match model_name:
|
||||
case 'STDEN':
|
||||
case 'STDEN':
|
||||
model = STDENModel(config)
|
||||
case 'STGODE':
|
||||
model = ODEGCN(config)
|
||||
return model
|
||||
case 'STGODE-LLM-GPT2':
|
||||
model = ODEGCN_LLM_GPT2(config)
|
||||
return model
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ matplotlib
|
|||
requests
|
||||
|
||||
# 文件压缩处理
|
||||
zipfile
|
||||
|
||||
# Kaggle数据下载
|
||||
kagglehub
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ def download_adj_data(current_dir, max_retries=3):
|
|||
下载并解压 adj.zip 文件,并显示下载进度条。
|
||||
如果下载失败,最多重试 max_retries 次。
|
||||
"""
|
||||
url = "https://code.zhang-heng.com/static/adj.zip"
|
||||
url = "http://code.zhang-heng.com/static/adj.zip"
|
||||
retries = 0
|
||||
|
||||
while retries <= max_retries:
|
||||
|
|
@ -143,4 +143,4 @@ def download_kaggle_data(current_dir):
|
|||
|
||||
# 主程序
|
||||
if __name__ == "__main__":
|
||||
check_and_download_data()
|
||||
check_and_download_data()
|
||||
|
|
|
|||
Loading…
Reference in New Issue