Compare commits

...

11 Commits
STGODE ... main

29 changed files with 902 additions and 593 deletions

20
.gitignore vendored
View File

@@ -160,8 +160,24 @@ cython_debug/
#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# Ignore dataset and experiment directories
.STDEN/
.data/PEMS08/
exp/
STDEN/
models/gpt2/
pre-trained/
# Ignore dataset file types
*.csv
*.npz
*.npy
*.pkl
*.h5
*.hdf5
# Ignore specific dataset directories
data/PEMS08/
data/PEMSD8/
test_semantic.npy
test_spatial.npy

8
.idea/.gitignore vendored
View File

@@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@@ -1,27 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="136" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="8">
<item index="0" class="java.lang.String" itemvalue="argparse" />
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
<item index="5" class="java.lang.String" itemvalue="setuptools" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
</project>

View File

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
</modules>
</component>
</project>

View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

View File

@@ -1,216 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="Git.Settings">
<excluded-from-favorite>
<branch-storage>
<map>
<entry type="LOCAL">
<value>
<list>
<branch-info repo="$PROJECT_DIR$" source="main" />
</list>
</value>
</entry>
</map>
</branch-storage>
</excluded-from-favorite>
<option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="main" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="ROOT_SYNC" value="DONT_SYNC" />
</component>
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
<component name="ProjectColorInfo">{
&quot;associatedIndex&quot;: 3
}</component>
<component name="ProjectId" id="3264JlB7seHXuXCCcdmTyEsXI45" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"Python.STDEN.executor": "Debug",
"Python.STGODE.executor": "Run",
"Python.main.executor": "Run",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true",
"git-widget-placeholder": "STGODE",
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"vue.rearranger.settings.migration": "true"
}
}]]></component>
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
<layout>
<window_info id="Space Code Reviews" show_stripe_button="false" />
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
<window_info id="Merge Requests" show_stripe_button="false" />
<window_info id="Commit_Guest" show_stripe_button="false" />
<window_info id="Pull Requests" show_stripe_button="false" />
<window_info id="Learn" show_stripe_button="false" />
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
<window_info id="Commit" order="1" weight="0.25" />
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
<window_info anchor="bottom" id="Version Control" order="0" />
<window_info anchor="bottom" id="Problems" order="1" />
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
<window_info anchor="bottom" id="Services" order="4" />
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
<window_info anchor="right" id="SciView" show_stripe_button="false" />
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
<window_info anchor="right" id="Database" order="2" weight="0.25" />
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
<window_info anchor="right" id="Plots" order="5" weight="0.1" />
</layout>
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/trainer" />
<recent name="$PROJECT_DIR$/configs/STDEN" />
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
</component>
<component name="RunManager" selected="Python.STGODE">
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STDEN/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.STDEN" />
<item itemvalue="Python.STGODE" />
</list>
</component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
</set>
</attachedChunks>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="" />
<created>1756727620810</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1756727620810</updated>
<workItem from="1756727623101" duration="4721000" />
<workItem from="1756856673845" duration="652000" />
<workItem from="1756864144998" duration="1063000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/models/STDEN/stden_model.py</url>
<line>131</line>
<option name="timeStamp" value="5" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN Coverage Results" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE Coverage Results" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
</component>
</project>

View File

@@ -1,3 +1,35 @@
# Project-I
Secret Project

## Prepare Env.
```
pip install -r requirements.txt
```
## Download dataset
```
python utils/download.py
```
## Download GPT-2 weights
```
mkdir -p models/gpt2
```
Download `config.json` and `pytorch_model.bin` from https://huggingface.co/openai-community/gpt2/tree/main:
```bash
wget "https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true" -O ./models/gpt2/config.json
wget "https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true" -O ./models/gpt2/pytorch_model.bin
```
Use PyTorch >= 2.6 to load the model.
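A quick, optional sanity check (a sketch, not part of the repo) that the downloaded files load; it assumes `transformers` is installed per `requirements.txt`:
```python
# Load GPT-2 strictly from the local files downloaded above; a missing or
# corrupt file will raise here rather than mid-training.
from transformers import GPT2Model

model = GPT2Model.from_pretrained("./models/gpt2", local_files_only=True)
print(model.config.n_layer, model.config.n_embd)  # 12 768 for the base gpt2 checkpoint
```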
## Run
```
python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
```

1
STDEN

@@ -1 +0,0 @@
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2

View File

@@ -0,0 +1,65 @@
basic:
  device: cuda:0
  dataset: PEMS08
  model: STGODE-LLM
  mode: test
  seed: 2025

data:
  dataset_dir: data/PEMS08
  val_batch_size: 32
  graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
  num_nodes: 170
  batch_size: 64
  input_dim: 1
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 24
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
  rnn_units: 64
  sigma1: 0.1
  sigma2: 10
  thres1: 0.6
  thres2: 0.5
  # LLM backbone settings
  llm_hidden: 128
  llm_layers: 4
  llm_heads: 4
  llm_pretrained: True

train:
  loss: mae
  batch_size: 64
  epochs: 100
  lr_init: 0.003
  mape_thresh: 0.001
  mae_thresh: None
  debug: False
  output_dim: 1
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "5,20,40,70"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  log_step: 3000
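For reference, a minimal sketch (the file path is an assumption; the diff does not show where this YAML lives) of how such a config is typically consumed — PyYAML parses it into nested dicts keyed by the `basic`/`data`/`model`/`train` sections:
```python
import yaml

# Hypothetical path, mirroring the configs/<MODEL>/<DATASET>.yaml pattern in the README.
with open("configs/STGODE_LLM/PEMS08.yaml") as f:
    config = yaml.safe_load(f)

print(config["basic"]["model"])       # 'STGODE-LLM'
print(config["model"]["llm_hidden"])  # 128
```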

View File

@@ -0,0 +1,66 @@
basic:
  device: cuda:0
  dataset: PEMS08
  model: STGODE-LLM-GPT2
  mode: train
  seed: 2025

data:
  dataset_dir: data/PEMS08
  val_batch_size: 16
  graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
  num_nodes: 170
  batch_size: 32
  input_dim: 1
  lag: 12
  horizon: 12
  val_ratio: 0.2
  test_ratio: 0.2
  tod: False
  normalizer: std
  column_wise: False
  default_graph: True
  add_time_in_day: True
  add_day_in_week: True
  steps_per_day: 24
  days_per_week: 7

model:
  input_dim: 1
  output_dim: 1
  history: 12
  horizon: 12
  num_features: 1
  rnn_units: 64
  sigma1: 0.1
  sigma2: 10
  thres1: 0.6
  thres2: 0.5
  # HF GPT-2 settings
  gpt2_name: gpt2
  gpt2_grad_ckpt: True
  gpt2_freeze: True
  gpt2_local_dir: ./models/gpt2

train:
  loss: mae
  batch_size: 32
  epochs: 100
  lr_init: 0.0003
  mape_thresh: 0.001
  mae_thresh: None
  debug: False
  output_dim: 1
  weight_decay: 0
  lr_decay: False
  lr_decay_rate: 0.3
  lr_decay_step: "10,30,60,90"
  early_stop: True
  early_stop_patience: 15
  grad_norm: False
  max_grad_norm: 5
  real_value: True
  log_step: 3000

View File

@@ -1,296 +0,0 @@
from,to,cost
9,153,310.6
153,62,330.9
62,111,332.9
111,11,324.2
11,28,336.0
28,169,133.7
138,135,354.7
135,133,387.9
133,163,337.1
163,20,352.0
20,19,420.8
19,14,351.3
14,39,340.2
39,164,350.3
164,167,365.2
167,70,359.0
70,59,388.2
59,58,305.7
58,67,294.4
67,66,299.5
66,55,313.3
55,53,332.1
53,150,278.9
150,61,308.4
61,64,311.4
64,63,243.6
47,65,372.8
65,48,319.4
48,49,309.7
49,54,320.5
54,56,318.3
56,57,297.9
57,68,293.5
68,69,342.5
69,60,318.0
60,17,305.9
17,5,321.4
5,18,402.2
18,22,447.4
22,30,377.5
30,29,417.7
29,21,360.8
21,132,407.6
132,134,386.9
134,136,350.2
123,121,326.3
121,140,385.2
140,118,393.0
118,96,296.7
96,94,398.2
94,86,337.1
86,78,473.8
78,46,353.4
46,152,385.7
152,157,350.0
157,35,354.4
35,77,356.1
77,52,354.2
52,3,357.8
3,16,382.4
16,0,55.7
42,12,335.1
12,139,328.8
139,168,412.6
168,154,337.3
154,143,370.7
143,10,6.3
107,105,354.6
105,104,386.9
104,148,362.1
148,97,316.3
97,101,380.7
101,137,361.4
137,102,365.5
102,24,375.5
24,166,312.2
129,156,256.1
156,33,329.1
33,32,356.5
91,89,405.6
89,147,347.0
147,15,351.7
15,44,339.5
44,41,350.8
41,43,322.6
43,100,338.9
100,83,347.9
83,87,327.2
87,88,321.0
88,75,335.8
75,51,384.8
51,73,391.1
73,71,289.3
31,155,260.0
155,34,320.4
34,128,393.3
145,115,399.4
115,112,328.1
112,8,469.4
8,117,816.2
117,125,397.1
125,127,372.7
127,109,380.5
109,161,355.5
161,110,367.7
110,160,102.0
72,159,342.9
159,50,383.3
50,74,354.1
74,82,350.2
82,81,335.4
81,99,391.6
99,84,354.9
84,13,306.4
13,40,327.4
40,162,413.9
162,108,301.9
108,146,317.8
146,85,376.6
85,90,347.0
26,27,341.6
27,6,359.4
6,149,417.8
149,126,388.0
126,124,384.3
124,7,763.3
7,114,323.1
114,113,351.6
113,116,411.9
116,144,262.0
25,103,350.2
103,23,376.3
23,165,396.4
165,38,381.0
38,92,368.0
92,37,336.3
37,130,357.8
130,106,532.3
106,131,166.5
1,2,371.6
2,4,338.1
4,76,429.0
76,36,366.1
36,158,344.5
158,151,350.1
151,45,358.8
45,93,340.9
93,80,329.9
80,79,384.1
79,95,335.7
95,98,320.9
98,119,340.3
119,120,376.8
120,122,393.1
122,141,428.7
141,142,359.3
30,165,379.6
165,29,41.7
29,38,343.3
65,72,297.9
72,48,21.5
17,153,375.6
153,5,256.3
153,62,330.9
18,6,499.4
6,22,254.0
22,149,185.4
22,4,257.9
4,30,236.8
30,76,307.0
95,98,320.9
98,144,45.1
45,93,340.9
93,106,112.2
162,151,113.6
151,108,192.9
108,45,359.8
146,92,311.2
92,85,343.9
85,37,373.2
13,169,326.2
169,40,96.1
124,13,460.7
13,7,305.5
7,40,624.1
124,169,145.2
169,7,631.5
90,132,152.2
26,32,106.7
9,129,148.3
129,153,219.6
31,26,116.0
26,155,270.7
9,128,142.2
128,153,215.0
153,167,269.7
167,62,64.8
62,70,332.6
124,169,145.2
169,7,631.5
44,169,397.8
169,41,124.0
44,124,375.7
124,41,243.9
41,7,519.4
6,14,289.3
14,149,259.0
149,39,206.9
144,98,45.1
19,4,326.8
4,14,178.6
14,76,299.0
15,151,136.4
151,44,203.1
45,106,260.6
106,93,112.2
20,165,132.5
165,19,289.2
89,92,323.2
92,147,321.9
147,37,48.2
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
78,107,143.9
107,46,236.3
104,147,277.5
147,148,84.7
20,101,201.2
101,19,534.4
19,137,245.5
8,42,759.5
42,117,58.9
44,42,342.3
42,41,102.5
44,8,789.1
8,41,657.4
41,117,160.5
168,167,172.4
167,154,165.2
143,128,81.9
128,10,88.2
118,145,250.6
145,96,85.1
15,152,135.0
152,44,204.6
19,77,320.7
77,14,299.8
14,52,127.6
14,127,314.8
127,39,280.4
39,109,237.0
31,160,116.5
160,155,272.4
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
32,160,107.7
72,162,3274.4
162,13,554.5
162,40,413.9
65,72,297.9
72,48,21.5
13,42,319.8
42,40,40.7
8,42,759.5
42,117,58.9
8,13,450.3
13,117,378.5
117,40,64.0
46,162,391.6
162,152,115.3
152,108,191.4
104,108,375.9
108,148,311.6
148,146,80.0
21,90,396.9
90,132,152.2
101,29,252.3
29,137,110.7
77,22,353.8
22,52,227.8
52,30,186.6
127,18,425.2
18,109,439.1
109,22,135.5
168,17,232.7
17,154,294.2
154,5,166.3
78,107,143.9
107,46,236.3
118,145,250.6
145,96,85.1

Binary file not shown.

Binary file not shown.

220
data/get_adj.py Normal file
View File

@@ -0,0 +1,220 @@
import csv
import os
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import norm
import scipy.sparse as sp
import torch
def get_adj(args):
    dataset_path = './data'
    match args['num_nodes']:
        case 358:
            dataset_name = 'PEMS03'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
            id_file = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id_file)
        case 307:
            dataset_name = 'PEMS04'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case 883:
            dataset_name = 'PEMS07'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
            A, adj = load_adj(args['num_nodes'], adj_path)
        case 170:
            dataset_name = 'PEMS08'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case _:
            raise ValueError(f"Unsupported num_nodes: {args['num_nodes']}")
    return adj
def get_gso(args):
    dataset_path = './data'
    match args['num_nodes']:
        case 358:
            dataset_name = 'PEMS03'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
            id_file = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
            A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id_file)
        case 307:
            dataset_name = 'PEMS04'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case 883:
            dataset_name = 'PEMS07'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
            A, adj = load_adj(args['num_nodes'], adj_path)
        case 170:
            dataset_name = 'PEMS08'
            adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
            A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case _:
            raise ValueError(f"Unsupported num_nodes: {args['num_nodes']}")
    gso = calc_gso(adj, args['gso_type'])
    if args['graph_conv_type'] == 'cheb_graph_conv':
        gso = calc_chebynet_gso(gso)
    gso = gso.toarray()
    gso = gso.astype(dtype=np.float32)
    gso = torch.from_numpy(gso).to(args['device'])
    return gso
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
    '''
    Parameters
    ----------
    adj_path: str, path of the csv file that contains the edge information
    num_nodes: int, the number of vertices
    id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
    std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization

    Returns
    ----------
    A: np.ndarray, adjacency matrix
    distanceA: np.ndarray, distance matrix (normalized if std=True)
    '''
    if 'npy' in adj_path:
        adj_mx = np.load(adj_path)
        return adj_mx, None
    else:
        A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
        distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
        # If id_filename is given, node IDs do not start from 0 and must be remapped.
        if id_filename:
            with open(id_filename, 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collect all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[id_dict[i], id_dict[j]] = 1
                    # make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)
                    distanceA[id_dict[i], id_dict[j]] = distance
        else:  # without id_filename, node IDs start from 0
            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collect all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[i, j] = 1
                    # make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)
                    distanceA[i, j] = distance
        # If std=True, apply Gaussian (z-score) normalization to all cost values from the CSV.
        if std:
            mean_cost = np.mean(costs)  # mean of the cost values
            std_cost = np.std(costs)    # standard deviation of the cost values
            for idx in np.ndindex(distanceA.shape):  # iterate over the matrix
                if distanceA[idx] > 0:  # only normalize non-zero entries
                    normalized_value = (distanceA[idx] - mean_cost) / std_cost
                    # keep normalized values positive
                    normalized_value = max(normalized_value, 1e-6)
                    distanceA[idx] = normalized_value
        # make sure the matrix has no all-zero rows
        row_sums = distanceA.sum(axis=1)
        zero_rows = np.where(row_sums == 0)[0]
        for row in zero_rows:
            distanceA[row, :] = 1e-6  # replace zero rows with a small non-zero default
        return A, distanceA
def calc_gso(dir_adj, gso_type):
    n_vertex = dir_adj.shape[0]
    if not sp.issparse(dir_adj):
        dir_adj = sp.csc_matrix(dir_adj)
    elif dir_adj.format != 'csc':
        dir_adj = dir_adj.tocsc()
    id_mat = sp.identity(n_vertex, format='csc')
    # Symmetrizing an adjacency matrix
    adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
    # adj = 0.5 * (dir_adj + dir_adj.transpose())
    if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
        adj = adj + id_mat
    if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
        row_sum = adj.sum(axis=1).A1
        # Check for zero or negative values in row_sum
        if np.any(row_sum <= 0):
            raise ValueError(
                "Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
        row_sum_inv_sqrt = np.power(row_sum, -0.5)
        row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0.  # Handle inf values
        deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
        # A_{sym} = D^{-0.5} * A * D^{-0.5}
        sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
        if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
            sym_norm_lap = id_mat - sym_norm_adj
            gso = sym_norm_lap
        else:
            gso = sym_norm_adj
    elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
        row_sum = np.sum(adj, axis=1).A1
        # Check for zero or negative values in row_sum
        if np.any(row_sum <= 0):
            raise ValueError(
                "Row sum contains zero or negative values, which is not allowed for random walk normalization.")
        row_sum_inv = np.power(row_sum, -1)
        row_sum_inv[np.isinf(row_sum_inv)] = 0.  # Handle inf values
        deg_inv = sp.diags(row_sum_inv, format='csc')
        # A_{rw} = D^{-1} * A
        rw_norm_adj = deg_inv.dot(adj)
        if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
            rw_norm_lap = id_mat - rw_norm_adj
            gso = rw_norm_lap
        else:
            gso = rw_norm_adj
    else:
        raise ValueError(f'{gso_type} is not defined.')
    # Check for nan or inf in the final result
    if np.isnan(gso.data).any() or np.isinf(gso.data).any():
        raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
    return gso
def calc_chebynet_gso(gso):
    if not sp.issparse(gso):
        gso = sp.csc_matrix(gso)
    elif gso.format != 'csc':
        gso = gso.tocsc()
    id_mat = sp.identity(gso.shape[0], format='csc')
    # If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
    eigval_max = norm(gso, 2)
    # If the gso is a symmetric or random-walk normalized Laplacian,
    # then the maximum eigenvalue is smaller than or equal to 2.
    if eigval_max >= 2:
        gso = gso - id_mat
    else:
        gso = 2 * gso / eigval_max - id_mat
    return gso
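A tiny worked example (not part of the repo) of what the `sym_renorm_adj` branch computes, here for a 3-node path graph; it assumes the functions above are importable:
```python
import numpy as np
import scipy.sparse as sp

# Path graph 0-1-2: after the renormalization trick (A + I), row i and column j
# are scaled by 1/sqrt(deg_i + 1), i.e. D^-0.5 (A + I) D^-0.5.
A = np.array([[0, 1, 0],
              [1, 0, 1],
              [0, 1, 0]], dtype=np.float32)
gso = calc_gso(sp.csc_matrix(A), 'sym_renorm_adj')
print(gso.toarray().round(3))
# [[0.5   0.408 0.   ]
#  [0.408 0.333 0.408]
#  [0.    0.408 0.5  ]]
```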

View File

@@ -0,0 +1,152 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
                                  padding=(0, padding))
            self.conv.weight.data.normal_(0, 0.01)
            self.chomp = Chomp1d(padding)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
        self.network = nn.Sequential(*layers)
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        # Residual connection: add the (optionally downsampled) input to the conv stack output.
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
        y = y.permute(0, 2, 3, 1)
        return y
class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)
class GPT2Backbone(nn.Module):
    def __init__(self, hidden_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int | None = None, use_pretrained: bool = True):
        super().__init__()
        self.hidden_size = hidden_size
        self.use_transformers = False
        self.model = None
        if n_embd is None:
            n_embd = hidden_size
        if use_pretrained:
            try:
                from transformers import GPT2Model, GPT2Config
                config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head, n_positions=1024, n_ctx=1024, vocab_size=1)
                self.model = GPT2Model(config)
                self.use_transformers = True
            except Exception:
                self.use_transformers = False
        if not self.use_transformers:
            encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=n_head, batch_first=True)
            self.model = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        if self.use_transformers:
            outputs = self.model(inputs_embeds=inputs_embeds)
            return outputs.last_hidden_state
        else:
            return self.model(inputs_embeds)
class ODEGCN_LLM(nn.Module):
    def __init__(self, config):
        super(ODEGCN_LLM, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(config)
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
             ])
        self.se_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
             ])
        self.history = num_timesteps_input
        self.horizon = num_timesteps_output
        hidden_size = int(args.get('llm_hidden', 128))
        llm_layers = int(args.get('llm_layers', 4))
        llm_heads = int(args.get('llm_heads', 4))
        use_pretrained = bool(args.get('llm_pretrained', True))
        self.to_llm_embed = nn.Linear(64, hidden_size)
        self.gpt2 = GPT2Backbone(hidden_size=hidden_size, n_layer=llm_layers, n_head=llm_heads, use_pretrained=use_pretrained)
        self.proj_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, self.horizon)
        )

    def forward(self, x):
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]
        # x: (B, N, T, 64) physical quantities after ODE-based transform
        B, N, T, C = x.shape
        x = self.to_llm_embed(x)                 # (B, N, T, H)
        x = x.contiguous().view(B * N, T, -1)    # (B*N, T, H)
        llm_hidden = self.gpt2(inputs_embeds=x)  # (B*N, T, H)
        last_state = llm_hidden[:, -1, :]        # (B*N, H)
        y = self.proj_head(last_state)           # (B*N, horizon)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)  # (B, horizon, N, 1)
        return y
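Two things worth noting about this backbone. Despite the flag name, the `use_pretrained=True` branch builds `GPT2Model(config)` from a fresh config, i.e. a randomly initialized GPT-2-architecture model rather than the pretrained checkpoint (the variant that loads real HF weights is `STGODE_LLM_GPT2.py` below). And the fallback path needs no `transformers` install at all, which makes a quick shape check (a sketch, not part of the repo) easy:
```python
import torch

# Fallback path: a plain nn.TransformerEncoder with the same (B*N, T, H) contract.
backbone = GPT2Backbone(hidden_size=128, n_layer=2, n_head=4, use_pretrained=False)
h = backbone(torch.randn(8, 12, 128))  # 8 node-series, 12 time steps, hidden 128
print(h.shape)  # torch.Size([8, 12, 128])
```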

View File

@@ -0,0 +1,4 @@
from .STGODE_LLM import ODEGCN_LLM

View File

@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat


class Chomp1d(nn.Module):
    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
                                  padding=(0, padding))
            self.conv.weight.data.normal_(0, 0.01)
            self.chomp = Chomp1d(padding)
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)
            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
        self.network = nn.Sequential(*layers)
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        y = x.permute(0, 3, 1, 2)
        # Residual connection: add the (optionally downsampled) input to the conv stack output.
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
        y = y.permute(0, 2, 3, 1)
        return y
class STGCNBlock(nn.Module):
    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        t = self.temporal1(X)
        t = self.odeg(t)
        t = self.temporal2(F.relu(t))
        return self.batch_norm(t)
class GPT2BackboneHF(nn.Module):
    def __init__(self, model_name: str | None = None, gradient_checkpointing: bool = False, freeze: bool = False, local_dir: str | None = None):
        super().__init__()
        from transformers import GPT2Model
        if local_dir is not None and len(local_dir) > 0:
            self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False)
        else:
            if model_name is None:
                model_name = 'gpt2'
            self.model = GPT2Model.from_pretrained(model_name, use_cache=False)
        if gradient_checkpointing:
            self.model.gradient_checkpointing_enable()
        self.hidden_size = self.model.config.hidden_size
        if freeze:
            for p in self.model.parameters():
                p.requires_grad = False

    def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
        outputs = self.model(inputs_embeds=inputs_embeds)
        return outputs.last_hidden_state
class ODEGCN_LLM_GPT2(nn.Module):
    def __init__(self, config):
        super(ODEGCN_LLM_GPT2, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        self.history = args['history']
        self.horizon = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(config)
        self.sp_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
             ])
        self.se_blocks = nn.ModuleList(
            [nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
             ])
        # HF GPT-2
        gpt2_name = args.get('gpt2_name', 'gpt2')
        grad_ckpt = bool(args.get('gpt2_grad_ckpt', False))
        gpt2_freeze = bool(args.get('gpt2_freeze', False))
        gpt2_local_dir = args.get('gpt2_local_dir', None)
        self.gpt2 = GPT2BackboneHF(gpt2_name, gradient_checkpointing=grad_ckpt, freeze=gpt2_freeze, local_dir=gpt2_local_dir)
        # Project ODE features to GPT-2 hidden size
        self.to_llm_embed = nn.Linear(64, self.gpt2.hidden_size)
        # Prediction head
        self.proj_head = nn.Sequential(
            nn.Linear(self.gpt2.hidden_size, self.gpt2.hidden_size),
            nn.ReLU(),
            nn.Linear(self.gpt2.hidden_size, self.horizon)
        )

    def forward(self, x):
        x = x[..., 0:1].permute(0, 2, 1, 3)
        outs = []
        for blk in self.sp_blocks:
            outs.append(blk(x))
        for blk in self.se_blocks:
            outs.append(blk(x))
        outs = torch.stack(outs)
        x = torch.max(outs, dim=0)[0]  # (B, N, T, 64)
        B, N, T, C = x.shape
        x = self.to_llm_embed(x).view(B * N, T, -1)
        llm_hidden = self.gpt2(inputs_embeds=x)
        last_state = llm_hidden[:, -1, :]
        y = self.proj_head(last_state)
        y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)
        return y
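With `gpt2_freeze: True` (as in the GPT-2 config above), only the ST-GCN blocks, the `to_llm_embed` projection, and the prediction head receive gradients. A hedged sketch (not part of the repo) for verifying that split, assuming the PEMS08 distance files needed by `get_A_hat` are in place and `config` is the dict loaded from the YAML:
```python
# 'config' is the nested dict loaded from configs/STGODE_LLM_GPT2/PEMS08.yaml.
model = ODEGCN_LLM_GPT2(config)
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
print(f"trainable={trainable:,}  frozen={frozen:,}")  # frozen is roughly 124M for base GPT-2
```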

View File

@@ -0,0 +1,4 @@
from .STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2

View File

@@ -1,12 +1,16 @@
from models.STDEN.stden_model import STDENModel
from models.STGODE.STGODE import ODEGCN
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2


def model_selector(config):
    model_name = config['basic']['model']
    model = None
    match model_name:
        case 'STDEN':
            model = STDENModel(config)
        case 'STGODE':
            model = ODEGCN(config)
        case 'STGODE-LLM-GPT2':
            model = ODEGCN_LLM_GPT2(config)
    return model
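A sketch of how the selector is presumably driven from `main.py` (the entry point itself is not shown in this diff):
```python
import yaml
from models.model_selector import model_selector

with open("configs/STGODE_LLM_GPT2/PEMS08.yaml") as f:
    config = yaml.safe_load(f)

model = model_selector(config)  # dispatches on config['basic']['model']
print(type(model).__name__)     # 'ODEGCN_LLM_GPT2'
```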

Binary file not shown.

43
requirements.txt Normal file
View File

@@ -0,0 +1,43 @@
# Core deep learning framework
torch
torchvision
torchaudio

# Scientific computing and data processing
numpy
pandas
scipy

# Machine learning utilities
scikit-learn

# Configuration and file handling
pyyaml

# Progress bars
tqdm

# Graph neural networks and distance computation
fastdtw

# ODE solvers
torchdiffeq

# NLP (for the GPT-2 model)
transformers

# Data visualization
matplotlib

# HTTP requests (for dataset download)
requests

# Archive handling
# Kaggle dataset download
kagglehub

# Miscellaneous
future

Binary file not shown.

Binary file not shown.

146
utils/download.py Normal file
View File

@@ -0,0 +1,146 @@
import os
import requests
import zipfile
import shutil
import kagglehub  # assumes kagglehub is an available library
from tqdm import tqdm
# Dictionary describing the expected file layout
def check_and_download_data():
    """
    Check the integrity of the data folder and download the
    appropriate data depending on which files are missing.
    """
    current_dir = os.getcwd()  # current working directory
    data_dir = os.path.join(current_dir, "data")  # assume the data folder lives in the current working directory
    expected_structure = {
        "PEMS03": ["PEMS03.csv", "PEMS03.npz", "PEMS03.txt", "PEMS03_dtw_distance.npy", "PEMS03_spatial_distance.npy"],
        "PEMS04": ["PEMS04.csv", "PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
        "PEMS07": ["PEMS07.csv", "PEMS07.npz", "PEMS07_dtw_distance.npy", "PEMS07_spatial_distance.npy"],
        "PEMS08": ["PEMS08.csv", "PEMS08.npz", "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy"]
    }
    missing_adj = False
    missing_main_files = False
    # Check whether the data folder exists
    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
        # print(f"Directory {data_dir} does not exist.")
        print("Downloading all required data files...")
        missing_adj = True
        missing_main_files = True
    else:
        # Check for get_adj.py in the data folder
        if "get_adj.py" not in os.listdir(data_dir):
            # print("get_adj.py is missing from the data folder.")
            missing_adj = True
        # Walk the expected file structure
        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)
            # Check whether the subfolder exists
            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
                # print(f"Subfolder {subfolder} does not exist.")
                missing_main_files = True
                continue
            # List the files actually present in the subfolder
            actual_files = os.listdir(subfolder_path)
            # Check for missing files
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
                    if "_dtw_distance.npy" in expected_file or "_spatial_distance.npy" in expected_file:
                        missing_adj = True
                    else:
                        missing_main_files = True
    # Dispatch the download logic based on what is missing
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir)
    return True
def download_adj_data(current_dir, max_retries=3):
    """
    Download and unpack adj.zip, showing a progress bar.
    Retries up to max_retries times on failure.
    """
    url = "http://code.zhang-heng.com/static/adj.zip"
    retries = 0
    while retries <= max_retries:
        try:
            print(f"Downloading adjacency matrix files from {url} ...")
            response = requests.get(url, stream=True)
            if response.status_code == 200:
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                t = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
                zip_file_path = os.path.join(current_dir, "adj.zip")
                with open(zip_file_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))
                t.close()
                # print("Download finished, file saved to:", zip_file_path)
                if os.path.exists(zip_file_path):
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                        zip_ref.extractall(current_dir)
                    # print("Dataset extracted to:", current_dir)
                    os.remove(zip_file_path)  # delete the zip file
                else:
                    print("Downloaded zip file not found; skipping extraction.")
                break  # download succeeded, exit the loop
            else:
                print(f"Download failed with status code {response.status_code}. Please check that the URL is valid.")
        except Exception as e:
            print(f"Error while downloading or extracting the dataset: {e}")
            print("If the URL is invalid, please verify it or try again later.")
        retries += 1
        if retries > max_retries:
            raise Exception(f"Download failed after the maximum number of retries ({max_retries}). Please check the URL or your network connection.")
def download_kaggle_data(current_dir):
    """
    Download the KaggleHub dataset and merge its data folder into the
    current working directory, overwriting conflicting files if the
    target folder already exists.
    """
    try:
        print("Downloading the PEMS dataset...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        # print("Path to KaggleHub dataset files:", path)
        if os.path.exists(path):
            data_folder_path = os.path.join(path, "data")
            if os.path.exists(data_folder_path):
                destination_path = os.path.join(current_dir, "data")
                # Merge the folders with shutil.copytree, overwriting conflicting files
                shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
                # print(f"The data folder was merged into: {destination_path}")
            # else:
            #     print("No data folder found; skipping the merge step.")
        # else:
        #     print("KaggleHub dataset path not found; skipping.")
    except Exception as e:
        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
# Entry point
if __name__ == "__main__":
    check_and_download_data()