Compare commits

...

7 Commits

37 changed files with 2554 additions and 4 deletions

5
.gitignore vendored
View File

@@ -160,3 +160,8 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
.STDEN/
.data/PEMS08/
exp/
STDEN/
models/gpt2/

8
.idea/.gitignore vendored Normal file
View File

@@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/Project-I.iml Normal file
View File

@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="136" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="8">
<item index="0" class="java.lang.String" itemvalue="argparse" />
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
<item index="5" class="java.lang.String" itemvalue="setuptools" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
</modules>
</component>
</project>

7
.idea/vcs.xml Normal file
View File

@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

216
.idea/workspace.xml Normal file
View File

@@ -0,0 +1,216 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="Git.Settings">
<excluded-from-favorite>
<branch-storage>
<map>
<entry type="LOCAL">
<value>
<list>
<branch-info repo="$PROJECT_DIR$" source="main" />
</list>
</value>
</entry>
</map>
</branch-storage>
</excluded-from-favorite>
<option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="main" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="ROOT_SYNC" value="DONT_SYNC" />
</component>
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
<component name="ProjectColorInfo">{
&quot;associatedIndex&quot;: 3
}</component>
<component name="ProjectId" id="3264JlB7seHXuXCCcdmTyEsXI45" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"Python.STDEN.executor": "Debug",
"Python.STGODE.executor": "Run",
"Python.main.executor": "Run",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true",
"git-widget-placeholder": "STGODE",
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"vue.rearranger.settings.migration": "true"
}
}]]></component>
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
<layout>
<window_info id="Space Code Reviews" show_stripe_button="false" />
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
<window_info id="Merge Requests" show_stripe_button="false" />
<window_info id="Commit_Guest" show_stripe_button="false" />
<window_info id="Pull Requests" show_stripe_button="false" />
<window_info id="Learn" show_stripe_button="false" />
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
<window_info id="Commit" order="1" weight="0.25" />
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
<window_info anchor="bottom" id="Version Control" order="0" />
<window_info anchor="bottom" id="Problems" order="1" />
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
<window_info anchor="bottom" id="Services" order="4" />
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
<window_info anchor="right" id="SciView" show_stripe_button="false" />
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
<window_info anchor="right" id="Database" order="2" weight="0.25" />
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
<window_info anchor="right" id="Plots" order="5" weight="0.1" />
</layout>
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/trainer" />
<recent name="$PROJECT_DIR$/configs/STDEN" />
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
</component>
<component name="RunManager" selected="Python.STGODE">
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STDEN/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.STDEN" />
<item itemvalue="Python.STGODE" />
</list>
</component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
</set>
</attachedChunks>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
<task active="true" id="Default" summary="Default task">
<changelist id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="" />
<created>1756727620810</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1756727620810</updated>
<workItem from="1756727623101" duration="4721000" />
<workItem from="1756856673845" duration="652000" />
<workItem from="1756864144998" duration="1063000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/models/STDEN/stden_model.py</url>
<line>131</line>
<option name="timeStamp" value="5" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
</component>
</project>

1
STDEN Submodule

@@ -0,0 +1 @@
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2

65
configs/STDEN/PEMS08.yaml Normal file
View File

@@ -0,0 +1,65 @@
basic:
device: cuda:0
dataset: PEMS08
model: STDEN
mode: train
seed: 2025
data:
dataset_dir: data/PEMS08
val_batch_size: 32
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
num_nodes: 170
batch_size: 32
input_dim: 1
lag: 24
horizon: 24
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 24
days_per_week: 7
model:
l1_decay: 0
seq_len: 12
horizon: 12
input_dim: 1
output_dim: 1
latent_dim: 4
n_traj_samples: 3
ode_method: dopri5
odeint_atol: 0.00001
odeint_rtol: 0.00001
rnn_units: 64
num_rnn_layers: 1
gcn_step: 2
filter_type: default # unkP IncP default
recg_type: gru
save_latent: false
nfe: false
train:
loss: mae
batch_size: 64
epochs: 100
lr_init: 0.003
mape_thresh: 0.001
mae_thresh: None
debug: False
output_dim: 1
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
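Note: the four top-level sections above are what the model code indexes (config['basic'], config['data'], config['model'], config['train']). A minimal loading sketch with PyYAML (assumed installed; the repo's own loader, utils.args_reader.config_loader, is not shown in this diff):

# Minimal sketch, not the repo's loader: read the YAML above into the nested
# dict the model code expects, e.g. config['model']['latent_dim'].
import yaml  # assumes PyYAML is available

with open('configs/STDEN/PEMS08.yaml') as f:
    config = yaml.safe_load(f)

print(config['basic']['model'])       # 'STDEN'
print(config['model']['latent_dim'])  # 4
print(config['data']['num_nodes'])    # 170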

View File

@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_GM
log_level: INFO
data:
batch_size: 32
dataset_dir: data/BJ_GM
val_batch_size: 32
graph_pkl_filename: data/sensor_graph/adj_GM.npy
model:
l1_decay: 0
seq_len: 12
horizon: 12
input_dim: 1
output_dim: 1
latent_dim: 4
n_traj_samples: 3
ode_method: dopri5
odeint_atol: 0.00001
odeint_rtol: 0.00001
rnn_units: 64
num_rnn_layers: 1
gcn_step: 2
filter_type: default # unkP IncP default
recg_type: gru
save_latent: false
nfe: false
train:
base_lr: 0.01
dropout: 0
load: 0
epoch: 0
epochs: 100
epsilon: 1.0e-3
lr_decay_ratio: 0.1
max_grad_norm: 5
min_learning_rate: 2.0e-06
optimizer: adam
patience: 20
steps: [20, 30, 40, 50]
test_every_n_epochs: 5

View File

@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_RM
log_level: INFO
data:
batch_size: 32
dataset_dir: data/BJ_RM
val_batch_size: 32
graph_pkl_filename: data/sensor_graph/adj_RM.npy
model:
l1_decay: 0
seq_len: 12
horizon: 12
input_dim: 1
output_dim: 1
latent_dim: 4
n_traj_samples: 3
ode_method: dopri5
odeint_atol: 0.00001
odeint_rtol: 0.00001
rnn_units: 64 # for recognition
num_rnn_layers: 1
gcn_step: 2
filter_type: default # unkP IncP default
recg_type: gru
save_latent: false
nfe: false
train:
base_lr: 0.01
dropout: 0
load: 0 # 0 for not load
epoch: 0
epochs: 100
epsilon: 1.0e-3
lr_decay_ratio: 0.1
max_grad_norm: 5
min_learning_rate: 2.0e-06
optimizer: adam
patience: 20
steps: [20, 30, 40, 50]
test_every_n_epochs: 5

View File

@@ -0,0 +1,44 @@
---
log_base_dir: logs/BJ_XZ
log_level: INFO
data:
batch_size: 32
dataset_dir: data/BJ_XZ
val_batch_size: 32
graph_pkl_filename: data/sensor_graph/adj_XZ.npy
model:
l1_decay: 0
seq_len: 12
horizon: 12
input_dim: 1
output_dim: 1
latent_dim: 4
n_traj_samples: 3
ode_method: dopri5
odeint_atol: 0.00001
odeint_rtol: 0.00001
rnn_units: 64
num_rnn_layers: 1
gcn_step: 2
filter_type: default # unkP IncP default
recg_type: gru
save_latent: false
nfe: false
train:
base_lr: 0.01
dropout: 0
load: 0 # 0 for not load
epoch: 0
epochs: 100
epsilon: 1.0e-3
lr_decay_ratio: 0.1
max_grad_norm: 5
min_learning_rate: 2.0e-06
optimizer: adam
patience: 20
steps: [20, 30, 40, 50]
test_every_n_epochs: 5

View File

@@ -0,0 +1,60 @@
basic:
device: cuda:0
dataset: PEMS08
model: STGODE
mode: train
seed: 2025
data:
dataset_dir: data/PEMS08
val_batch_size: 32
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
num_nodes: 170
batch_size: 64
input_dim: 1
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 24
days_per_week: 7
model:
input_dim: 1
output_dim: 1
history: 12
horizon: 12
num_features: 1
rnn_units: 64
sigma1: 0.1
sigma2: 10
thres1: 0.6
thres2: 0.5
train:
loss: mae
batch_size: 64
epochs: 100
lr_init: 0.003
mape_thresh: 0.001
mae_thresh: None
debug: False
output_dim: 1
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
log_step: 3000
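Note: sigma1/sigma2 and thres1/thres2 are the Gaussian-kernel bandwidths and cut-offs that models/STGODE/adj.py (later in this diff) applies to the DTW and spatial distance matrices. A toy numpy sketch of that thresholding, with made-up distances:

# Hedged sketch of how sigma2/thres2 are used for the spatial graph:
# Gaussian kernel over (normalized) distances, then thresholding.
import numpy as np

dist = np.array([[0.0, 1.0], [1.0, 0.0]])   # toy pairwise distance matrix
sigma2, thres2 = 10, 0.5                     # values from the config above

sp_matrix = np.exp(-dist ** 2 / sigma2 ** 2)
sp_matrix[sp_matrix < thres2] = 0
print(sp_matrix)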

296
data/PEMS08/PEMS08.csv Executable file
View File

@@ -0,0 +1,296 @@
from,to,cost
9,153,310.6
153,62,330.9
62,111,332.9
111,11,324.2
11,28,336.0
28,169,133.7
138,135,354.7
135,133,387.9
133,163,337.1
163,20,352.0
20,19,420.8
19,14,351.3
14,39,340.2
39,164,350.3
164,167,365.2
167,70,359.0
70,59,388.2
59,58,305.7
58,67,294.4
67,66,299.5
66,55,313.3
55,53,332.1
53,150,278.9
150,61,308.4
61,64,311.4
64,63,243.6
47,65,372.8
65,48,319.4
48,49,309.7
49,54,320.5
54,56,318.3
56,57,297.9
57,68,293.5
68,69,342.5
69,60,318.0
60,17,305.9
17,5,321.4
5,18,402.2
18,22,447.4
22,30,377.5
30,29,417.7
29,21,360.8
21,132,407.6
132,134,386.9
134,136,350.2
123,121,326.3
121,140,385.2
140,118,393.0
118,96,296.7
96,94,398.2
94,86,337.1
86,78,473.8
78,46,353.4
46,152,385.7
152,157,350.0
157,35,354.4
35,77,356.1
77,52,354.2
52,3,357.8
3,16,382.4
16,0,55.7
42,12,335.1
12,139,328.8
139,168,412.6
168,154,337.3
154,143,370.7
143,10,6.3
107,105,354.6
105,104,386.9
104,148,362.1
148,97,316.3
97,101,380.7
101,137,361.4
137,102,365.5
102,24,375.5
24,166,312.2
129,156,256.1
156,33,329.1
33,32,356.5
91,89,405.6
89,147,347.0
147,15,351.7
15,44,339.5
44,41,350.8
41,43,322.6
43,100,338.9
100,83,347.9
83,87,327.2
87,88,321.0
88,75,335.8
75,51,384.8
51,73,391.1
73,71,289.3
31,155,260.0
155,34,320.4
34,128,393.3
145,115,399.4
115,112,328.1
112,8,469.4
8,117,816.2
117,125,397.1
125,127,372.7
127,109,380.5
109,161,355.5
161,110,367.7
110,160,102.0
72,159,342.9
159,50,383.3
50,74,354.1
74,82,350.2
82,81,335.4
81,99,391.6
99,84,354.9
84,13,306.4
13,40,327.4
40,162,413.9
162,108,301.9
108,146,317.8
146,85,376.6
85,90,347.0
26,27,341.6
27,6,359.4
6,149,417.8
149,126,388.0
126,124,384.3
124,7,763.3
7,114,323.1
114,113,351.6
113,116,411.9
116,144,262.0
25,103,350.2
103,23,376.3
23,165,396.4
165,38,381.0
38,92,368.0
92,37,336.3
37,130,357.8
130,106,532.3
106,131,166.5
1,2,371.6
2,4,338.1
4,76,429.0
76,36,366.1
36,158,344.5
158,151,350.1
151,45,358.8
45,93,340.9
93,80,329.9
80,79,384.1
79,95,335.7
95,98,320.9
98,119,340.3
119,120,376.8
120,122,393.1
122,141,428.7
141,142,359.3
30,165,379.6
165,29,41.7
29,38,343.3
65,72,297.9
72,48,21.5
17,153,375.6
153,5,256.3
153,62,330.9
18,6,499.4
6,22,254.0
22,149,185.4
22,4,257.9
4,30,236.8
30,76,307.0
95,98,320.9
98,144,45.1
45,93,340.9
93,106,112.2
162,151,113.6
151,108,192.9
108,45,359.8
146,92,311.2
92,85,343.9
85,37,373.2
13,169,326.2
169,40,96.1
124,13,460.7
13,7,305.5
7,40,624.1
124,169,145.2
169,7,631.5
90,132,152.2
26,32,106.7
9,129,148.3
129,153,219.6
31,26,116.0
26,155,270.7
9,128,142.2
128,153,215.0
153,167,269.7
167,62,64.8
62,70,332.6
124,169,145.2
169,7,631.5
44,169,397.8
169,41,124.0
44,124,375.7
124,41,243.9
41,7,519.4
6,14,289.3
14,149,259.0
149,39,206.9
144,98,45.1
19,4,326.8
4,14,178.6
14,76,299.0
15,151,136.4
151,44,203.1
45,106,260.6
106,93,112.2
20,165,132.5
165,19,289.2
89,92,323.2
92,147,321.9
147,37,48.2
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
78,107,143.9
107,46,236.3
104,147,277.5
147,148,84.7
20,101,201.2
101,19,534.4
19,137,245.5
8,42,759.5
42,117,58.9
44,42,342.3
42,41,102.5
44,8,789.1
8,41,657.4
41,117,160.5
168,167,172.4
167,154,165.2
143,128,81.9
128,10,88.2
118,145,250.6
145,96,85.1
15,152,135.0
152,44,204.6
19,77,320.7
77,14,299.8
14,52,127.6
14,127,314.8
127,39,280.4
39,109,237.0
31,160,116.5
160,155,272.4
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
32,160,107.7
72,162,3274.4
162,13,554.5
162,40,413.9
65,72,297.9
72,48,21.5
13,42,319.8
42,40,40.7
8,42,759.5
42,117,58.9
8,13,450.3
13,117,378.5
117,40,64.0
46,162,391.6
162,152,115.3
152,108,191.4
104,108,375.9
108,148,311.6
148,146,80.0
21,90,396.9
90,132,152.2
101,29,252.3
29,137,110.7
77,22,353.8
22,52,227.8
52,30,186.6
127,18,425.2
18,109,439.1
109,22,135.5
168,17,232.7
17,154,294.2
154,5,166.3
78,107,143.9
107,46,236.3
118,145,250.6
145,96,85.1

BIN
data/PEMS08/PEMS08.npz Executable file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@@ -2,6 +2,7 @@ import numpy as np
import os
def load_dataset(config):
dataset_name = config['basic']['dataset']
node_num = config['data']['num_nodes']
input_dim = config['data']['input_dim']
@@ -10,4 +11,8 @@ def load_dataset(config):
case 'EcoSolar':
data_path = os.path.join('./data/EcoSolar.npy')
data = np.load(data_path)[:, :node_num, :input_dim]
case 'PEMS08':
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
data = np.load(data_path)['data'][:, :node_num, :input_dim]
return data
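Note: the new PEMS08 branch reduces to an np.load on the .npz file added in this commit. A hedged sketch of the same slicing, run from the repo root (num_nodes and input_dim taken from the PEMS08 configs):

import numpy as np

node_num, input_dim = 170, 1                        # from configs/*/PEMS08.yaml
raw = np.load('./data/PEMS08/PEMS08.npz')['data']   # (T, N, F)
data = raw[:, :node_num, :input_dim]
print(data.shape)                                   # expected (T, 170, 1)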

9
data/graph_loader.py Normal file
View File

@@ -0,0 +1,9 @@
import numpy as np
def load_graph(config):
dataset_path = config['data']['graph_pkl_filename']
graph = np.load(dataset_path)
# Replace NaN/inf values with 0
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
return graph
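A hedged usage sketch (run from the repo root once the distance file referenced by the configs exists):

import numpy as np
from data.graph_loader import load_graph

config = {'data': {'graph_pkl_filename': 'data/PEMS08/PEMS08_spatial_distance.npy'}}
adj = load_graph(config)                    # NaN/inf entries replaced with 0
print(adj.shape, np.isfinite(adj).all())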

View File

@@ -3,8 +3,6 @@
Main program for the spatio-temporal deep learning forecasting project
Handles spatio-temporal data in the format (batch_size, seq_len, num_nodes, features)
"""
import os
from utils.args_reader import config_loader
import utils.init as init
import torch

View File

@@ -0,0 +1,88 @@
### STDEN modules and execution flow (nested call-level table)
Module | Class/Function | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | STDENModel.forward | inputs: (seq_len, batch_size, num_edges x input_dim) | outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
1.1 | Encoder_z0_RNN.forward | (seq_len, batch_size, num_edges x input_dim) | mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
1.1.1 | utils.sample_standard_gaussian | mu: (n_traj, batch, num_nodes x latent_dim); sigma: same shape | z0: (n_traj, batch, num_nodes x latent_dim)
1.2 | DiffeqSolver.forward | first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) | sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
1.2.1 | ODEFunc.forward | t_local: scalar/1-D; y: (B, num_nodes x latent_dim) | dy/dt: (B, num_nodes x latent_dim)
1.3 | Decoder.forward | (horizon, n_traj, batch, num_nodes x latent_dim) | (horizon, batch, num_edges x output_dim)
---
### Detail module — Encoder_z0_RNN
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to edge batch | (seq_len, batch, num_edges x input_dim) | (seq_len, batch, num_edges, input_dim)
2 | Fold edges into the batch | (seq_len, batch, num_edges, input_dim) | (seq_len, batch x num_edges, input_dim)
3 | GRU sequence encoding | same as above | (seq_len, batch x num_edges, rnn_units)
4 | Take the last time step | same as above | (batch x num_edges, rnn_units)
5 | Restore the edge dimension | (batch x num_edges, rnn_units) | (batch, num_edges, rnn_units)
6 | Transpose + edge-to-node mapping | (batch, num_edges, rnn_units) via inv_grad | (batch, num_nodes, rnn_units)
7 | Fully connected map to 2x latent | (batch, num_nodes, rnn_units) | (batch, num_nodes, 2 x latent_dim)
8 | Split into mean/std | same as above | mean/std: (batch, num_nodes, latent_dim)
9 | Flatten and add a time dimension | (batch, num_nodes, latent_dim) | (1, batch, num_nodes x latent_dim)
Note: inv_grad comes from `utils.graph_grad(adj).T` with rescaling; `hiddens_to_z0` is a two-layer MLP with Tanh followed by a linear map to 2 x latent_dim.
---
### Detail module — sampling (utils.sample_standard_gaussian)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Repeat to n_traj | mean/std: (1, batch, N·Z) → repeated | (n_traj, batch, N·Z)
2 | Reparameterized sampling | mu, sigma | z0: (n_traj, batch, N·Z)
Here N·Z = num_nodes x latent_dim.
---
### Detail module — DiffeqSolver (including the ODEFunc call)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Merge sample dimension | first_point: (n_traj, batch, N·Z) | (n_traj x batch, N·Z)
2 | ODE integration | t: (horizon,), y0 | pred_y: (horizon, n_traj x batch, N·Z)
3 | Restore dimensions | same as above | (horizon, n_traj, batch, N·Z)
4 | Cost statistics | odefunc.nfe, elapsed_time | fe: (nfe:int, time:float)
The default ODEFunc (filter_type="default") is a diffusion process: random-walk supports plus multi-order graph-convolution gating.
---
### Detail module — ODEFunc (default diffusion process)
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Shape handling | y: (B, N·Z) → (B, N, Z) | (B, N, Z)
2 | Multi-order graph convolution _gconv | (B, N, Z) | (B, N, Z'), with Z' set as needed (usually kept at Z)
3 | Gate θ | _gconv(..., output=latent_dim) → Sigmoid | θ: (B, N·Z)
4 | Field generation ode_func_net | stacked _gconv + activation | f(y): (B, N·Z)
5 | Right-hand side | - θ ⊙ f(y) | dy/dt: (B, N·Z)
Notes:
- The support matrices come from `utils.calculate_random_walk_matrix(adj)` (forward/backward) and build the multi-order channels of a sparse Chebyshev recurrence.
- With `filter_type="unkP"`, a fully connected network built by `create_net` computes the gradient pointwise in the node domain instead.
---
### Detail module — Decoder
Step | Operation | Input (shape) | Output (shape)
--- | --- | --- | ---
1 | Reshape to the node domain | (T, S, B, N·Z) → (T, S, B, N, Z) | (T, S, B, N, Z)
2 | Node-to-edge mapping | multiply by graph_grad (N, E) | (T, S, B, Z, E)
3 | Average over trajectories and channels | mean over the S and Z dimensions | (T, B, E)
4 | Flatten to the output dimension | accounting for output_dim (usually 1) | (T, B, E x output_dim)
Symbols: T=horizon, S=n_traj_samples, N=num_nodes, E=num_edges, Z=latent_dim, B=batch.
---
### Notes and conventions
- Internally the temporal input is edge-flattened: `(seq_len, batch, num_edges x input_dim)`.
- Graph operators: `utils.graph_grad(adj)` has shape `(N, E)`; `utils.calculate_random_walk_matrix(adj)` builds the random-walk sparse matrices used by the graph convolution.
- Key hyperparameters (passed in via the config): `latent_dim`, `rnn_units`, `gcn_step`, `n_traj_samples`, `ode_method`, `horizon`, `input_dim`, `output_dim`.
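
A toy sketch of step 1.1.1 (the reparameterized sampling that links the encoder to the ODE solver), written with plain torch rather than the repo's utils, using the shapes from the tables above:

import torch

n_traj, batch, num_nodes, latent_dim = 3, 2, 4, 4
mean = torch.zeros(1, batch, num_nodes * latent_dim)
std = torch.ones(1, batch, num_nodes * latent_dim)

mu = mean.repeat(n_traj, 1, 1)              # (n_traj, batch, N*Z)
sigma = std.repeat(n_traj, 1, 1)
z0 = mu + sigma * torch.randn_like(mu)      # z0: (n_traj, batch, N*Z)
print(z0.shape)                             # torch.Size([3, 2, 16])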

View File

@@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import time
from torchdiffeq import odeint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class DiffeqSolver(nn.Module):
def __init__(self, odefunc, method, latent_dim,
odeint_rtol = 1e-4, odeint_atol = 1e-5):
nn.Module.__init__(self)
self.ode_method = method
self.odefunc = odefunc
self.latent_dim = latent_dim
self.rtol = odeint_rtol
self.atol = odeint_atol
def forward(self, first_point, time_steps_to_pred):
"""
Decode the trajectory through the ODE solver.
:param time_steps_to_pred: horizon
:param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
:return: pred_y: # shape (horizon, n_traj_samples, batch_size, self.num_nodes * self.output_dim)
"""
n_traj_samples, batch_size = first_point.size()[0], first_point.size()[1]
first_point = first_point.reshape(n_traj_samples * batch_size, -1) # reduce the complexity by merging dimension
# pred_y shape: (horizon, n_traj_samples * batch_size, num_nodes * latent_dim)
start_time = time.time()
self.odefunc.nfe = 0
pred_y = odeint(self.odefunc,
first_point,
time_steps_to_pred,
rtol=self.rtol,
atol=self.atol,
method=self.ode_method)
time_fe = time.time() - start_time
# pred_y shape: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
pred_y = pred_y.reshape(pred_y.size()[0], n_traj_samples, batch_size, -1)
# assert(pred_y.size()[1] == n_traj_samples)
# assert(pred_y.size()[2] == batch_size)
return pred_y, (self.odefunc.nfe, time_fe)
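A hedged toy usage of the solver (assumes torchdiffeq is installed and the repo root is on the import path); the only contract is that the ODE function exposes an nfe counter and a forward(t, y):

import torch
import torch.nn as nn
from models.STDEN.diffeq_solver import DiffeqSolver

class ToyODEFunc(nn.Module):
    """dy/dt = -y; DiffeqSolver expects the function to carry an `nfe` counter."""
    def __init__(self):
        super().__init__()
        self.nfe = 0
    def forward(self, t, y):
        self.nfe += 1
        return -y

solver = DiffeqSolver(ToyODEFunc(), method='dopri5', latent_dim=4)
y0 = torch.randn(3, 2, 4)                  # (n_traj_samples, batch, latent)
t = torch.linspace(0, 1, steps=12)
pred_y, (nfe, elapsed) = solver(y0, t)
print(pred_y.shape)                        # torch.Size([12, 3, 2, 4])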

165
models/STDEN/ode_func.py Normal file
View File

@@ -0,0 +1,165 @@
import numpy as np
import torch
import torch.nn as nn
from models.STDEN import utils
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class LayerParams:
def __init__(self, rnn_network: nn.Module, layer_type: str):
self._rnn_network = rnn_network
self._params_dict = {}
self._biases_dict = {}
self._type = layer_type
def get_weights(self, shape):
if shape not in self._params_dict:
nn_param = nn.Parameter(torch.empty(*shape, device=device))
nn.init.xavier_normal_(nn_param)
self._params_dict[shape] = nn_param
self._rnn_network.register_parameter('{}_weight_{}'.format(self._type, str(shape)),
nn_param)
return self._params_dict[shape]
def get_biases(self, length, bias_start=0.0):
if length not in self._biases_dict:
biases = nn.Parameter(torch.empty(length, device=device))
nn.init.constant_(biases, bias_start)
self._biases_dict[length] = biases
self._rnn_network.register_parameter('{}_biases_{}'.format(self._type, str(length)),
biases)
return self._biases_dict[length]
class ODEFunc(nn.Module):
def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
gen_layers=1, nonlinearity='tanh', filter_type="default"):
"""
:param num_units: dimensionality of the hidden layers
:param latent_dim: dimensionality used for ODE (input and output). Analog of a continuous latent state
:param adj_mx:
:param gcn_step:
:param num_nodes:
:param gen_layers: hidden layers in each ode func.
:param nonlinearity:
:param filter_type: default
:param use_gc_for_ru: whether to use Graph convolution to calculate the reset and update gates.
"""
super(ODEFunc, self).__init__()
self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu
self._num_nodes = num_nodes
self._num_units = num_units # hidden dimension
self._latent_dim = latent_dim
self._gen_layers = gen_layers
self.nfe = 0
self._filter_type = filter_type
if(self._filter_type == "unkP"):
ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
utils.init_network_weights(ode_func_net)
self.gradient_net = ode_func_net
else:
self._gcn_step = gcn_step
self._gconv_params = LayerParams(self, 'gconv')
self._supports = []
supports = []
supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)
for support in supports:
self._supports.append(self._build_sparse_matrix(support))
@staticmethod
def _build_sparse_matrix(L):
L = L.tocoo()
indices = np.column_stack((L.row, L.col))
# this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
return L
def forward(self, t_local, y, backwards = False):
"""
Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point
t_local: current time point
y: value at the current time point, shape (B, num_nodes * latent_dim)
:return
- Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`.
"""
self.nfe += 1
grad = self.get_ode_gradient_nn(t_local, y)
if backwards:
grad = -grad
return grad
def get_ode_gradient_nn(self, t_local, inputs):
if(self._filter_type == "unkP"):
grad = self._fc(inputs)
elif (self._filter_type == "IncP"):
grad = - self.ode_func_net(inputs)
else: # default is diffusion process
# theta shape: (B, num_nodes * latent_dim)
theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
grad = - theta * self.ode_func_net(inputs)
return grad
def ode_func_net(self, inputs):
c = inputs
for i in range(self._gen_layers):
c = self._gconv(c, self._num_units)
c = self._activation(c)
c = self._gconv(c, self._latent_dim)
c = self._activation(c)
return c
def _fc(self, inputs):
batch_size = inputs.size()[0]
grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
return grad.reshape(batch_size, self._num_nodes * self._latent_dim) # (batch_size, num_nodes * latent_dim)
@staticmethod
def _concat(x, x_):
x_ = x_.unsqueeze(0)
return torch.cat([x, x_], dim=0)
def _gconv(self, inputs, output_size, bias_start=0.0):
# Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
batch_size = inputs.shape[0]
inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
# state = torch.reshape(state, (batch_size, self._num_nodes, -1))
# inputs_and_state = torch.cat([inputs, state], dim=2)
input_size = inputs.size(2)
x = inputs
x0 = x.permute(1, 2, 0) # (num_nodes, total_arg_size, batch_size)
x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
x = torch.unsqueeze(x0, 0)
if self._gcn_step == 0:
pass
else:
for support in self._supports:
x1 = torch.sparse.mm(support, x0)
x = self._concat(x, x1)
for k in range(2, self._gcn_step + 1):
x2 = 2 * torch.sparse.mm(support, x1) - x0
x = self._concat(x, x2)
x1, x0 = x2, x1
num_matrices = len(self._supports) * self._gcn_step + 1 # Adds for x itself.
x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
x = x.permute(3, 1, 2, 0) # (batch_size, num_nodes, input_size, order)
x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])
weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
x = torch.matmul(x, weights) # (batch_size * self._num_nodes, output_size)
biases = self._gconv_params.get_biases(output_size, bias_start)
x += biases
# Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
return torch.reshape(x, [batch_size, self._num_nodes * output_size])
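A dense toy sketch of the propagation rule that _gconv stacks, i.e. the Chebyshev-style recurrence x_k = 2·S·x_{k-1} - x_{k-2} over a row-normalized random-walk support; illustrative only, not the sparse implementation above:

import torch

adj = torch.tensor([[0., 1., 0.],
                    [1., 0., 1.],
                    [0., 1., 0.]])
support = adj / adj.sum(dim=1, keepdim=True)   # row-normalized random walk (transposed in the repo)

x0 = torch.randn(3, 4)                         # (num_nodes, features * batch), as inside _gconv
x1 = support @ x0
terms = [x0, x1]
for k in range(2, 3):                          # gcn_step = 2
    terms.append(2 * support @ terms[-1] - terms[-2])
x = torch.stack(terms)                         # (num_matrices, num_nodes, features * batch)
print(x.shape)                                 # torch.Size([3, 3, 4])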

181
models/STDEN/stden_model.py Normal file
View File

@@ -0,0 +1,181 @@
import torch
import torch.nn as nn
from torch.nn.modules.rnn import GRU
from models.STDEN.ode_func import ODEFunc
from models.STDEN.diffeq_solver import DiffeqSolver
from models.STDEN import utils
from data.graph_loader import load_graph
class EncoderAttrs:
"""编码器属性配置类"""
def __init__(self, config, adj_mx):
self.adj_mx = adj_mx
self.num_nodes = adj_mx.shape[0]
self.num_edges = (adj_mx > 0.).sum()
self.gcn_step = int(config.get('gcn_step', 2))
self.filter_type = config.get('filter_type', 'default')
self.num_rnn_layers = int(config.get('num_rnn_layers', 1))
self.rnn_units = int(config.get('rnn_units'))
self.latent_dim = int(config.get('latent_dim', 4))
class STDENModel(nn.Module, EncoderAttrs):
"""STDEN主模型时空微分方程网络"""
def __init__(self, config):
nn.Module.__init__(self)
adj_mx = load_graph(config)
EncoderAttrs.__init__(self, config['model'], adj_mx)
# Recognition network
self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)
model_kwargs = config['model']
# ODE solver configuration
self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
self.ode_method = model_kwargs.get('ode_method', 'dopri5')
self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))
# Create the ODE function and solver
odefunc = ODEFunc(
self.ode_gen_dim, self.latent_dim, adj_mx,
self.gcn_step, self.num_nodes, filter_type=self.filter_type
)
self.diffeq_solver = DiffeqSolver(
odefunc, self.ode_method, self.latent_dim,
odeint_rtol=self.rtol, odeint_atol=self.atol
)
# Latent-feature saving settings
self.save_latent = bool(model_kwargs.get('save_latent', False))
self.latent_feat = None
# Decoder
self.horizon = int(model_kwargs.get('horizon', 1))
self.out_feat = int(model_kwargs.get('output_dim', 1))
self.decoder = Decoder(
self.out_feat, adj_mx, self.num_nodes, self.num_edges
)
def forward(self, inputs, labels=None, batches_seen=None):
"""
seq2seq forward pass
:param inputs: (seq_len, batch_size, num_edges * input_dim)
:param labels: (horizon, batch_size, num_edges * output_dim)
:param batches_seen: number of batches seen so far
:return: outputs: (horizon, batch_size, num_edges * output_dim)
"""
# Encode the initial latent state
B, T, N, C = inputs.shape
inputs = inputs.view(T, B, N * C)
first_point_mu, first_point_std = self.encoder_z0(inputs)
# Sample trajectories
means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)
# Time steps to predict
time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float()
time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)
# Solve the ODE
sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)
if self.save_latent:
self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
# Decode the outputs
outputs = self.decoder(sol_ys)
outputs = outputs.view(B, T, N, C)
return outputs, fe
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
"""RNN编码器将输入序列编码为初始潜在状态"""
def __init__(self, config, adj_mx):
nn.Module.__init__(self)
EncoderAttrs.__init__(self, config, adj_mx)
self.recg_type = config.get('recg_type', 'gru')
self.input_dim = int(config.get('input_dim', 1))
if self.recg_type == 'gru':
self.gru_rnn = GRU(self.input_dim, self.rnn_units)
else:
raise NotImplementedError("Only the 'gru' recognition network is supported")
# Mapping from hidden states to z0
self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
self.inv_grad[self.inv_grad != 0.] = 0.5
self.hiddens_to_z0 = nn.Sequential(
nn.Linear(self.rnn_units, 50),
nn.Tanh(),
nn.Linear(50, self.latent_dim * 2)
)
utils.init_network_weights(self.hiddens_to_z0)
def forward(self, inputs):
"""
Encoder forward pass
:param inputs: (seq_len, batch_size, num_edges * input_dim)
:return: mean, std: (1, batch_size, num_nodes * latent_dim)
"""
seq_len, batch_size = inputs.size(0), inputs.size(1)
# Reshape the inputs for processing
inputs = inputs.reshape(seq_len, batch_size, self.num_nodes, self.input_dim)
inputs = inputs.reshape(seq_len, batch_size * self.num_nodes, self.input_dim)
# GRU encoding
outputs, _ = self.gru_rnn(inputs)
last_output = outputs[-1]
# Reshape and transpose dimensions
last_output = torch.reshape(last_output, (batch_size, self.num_nodes, -1))
last_output = torch.transpose(last_output, -2, -1)
last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)
# Produce mean and standard deviation
mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
mean = mean.reshape(batch_size, -1)
std = std.reshape(batch_size, -1).abs()
return mean.unsqueeze(0), std.unsqueeze(0)
class Decoder(nn.Module):
"""解码器:将潜在状态解码为输出"""
def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
super(Decoder, self).__init__()
self.num_nodes = num_nodes
self.num_edges = num_edges
self.grap_grad = utils.graph_grad(adj_mx)
self.output_dim = output_dim
def forward(self, inputs):
"""
:param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
:return: outputs: (horizon, batch_size, num_edges * output_dim)
"""
horizon, n_traj_samples, batch_size = inputs.size()[:3]
# Reshape the inputs
inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
latent_dim = inputs.size(-2)
# Graph gradient transform: nodes to edges
outputs = torch.matmul(inputs, self.grap_grad)
# Reshape and average over sampled trajectories
outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim, self.num_edges, self.output_dim)
outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
outputs = outputs.reshape(horizon, batch_size, -1)
return outputs
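A hedged toy run of the Decoder alone (node-to-edge mapping via graph_grad), assuming the repo packages and torchdiffeq are importable from the project root; the 3-node ring below has as many edges as nodes, so the docstring shapes hold:

import numpy as np
import torch
from models.STDEN.stden_model import Decoder

adj = np.array([[0., 1., 0.],
                [0., 0., 1.],
                [1., 0., 0.]])                # 3 nodes, 3 directed edges
dec = Decoder(output_dim=1, adj_mx=adj, num_nodes=3, num_edges=3)

latent = torch.randn(12, 3, 2, 3 * 4)         # (horizon, n_traj, batch, N * latent_dim)
out = dec(latent)
print(out.shape)                              # torch.Size([12, 2, 3]) -> (horizon, batch, E * output_dim)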

234
models/STDEN/utils.py Normal file
View File

@@ -0,0 +1,234 @@
import logging
import numpy as np
import os
import time
import scipy.sparse as sp
import sys
import torch
import torch.nn as nn
class DataLoader(object):
def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
"""
:param xs:
:param ys:
:param batch_size:
:param pad_with_last_sample: pad with the last sample to make the number of samples divisible by batch_size.
"""
self.batch_size = batch_size
self.current_ind = 0
if pad_with_last_sample:
num_padding = (batch_size - (len(xs) % batch_size)) % batch_size
x_padding = np.repeat(xs[-1:], num_padding, axis=0)
y_padding = np.repeat(ys[-1:], num_padding, axis=0)
xs = np.concatenate([xs, x_padding], axis=0)
ys = np.concatenate([ys, y_padding], axis=0)
self.size = len(xs)
self.num_batch = int(self.size // self.batch_size)
if shuffle:
permutation = np.random.permutation(self.size)
xs, ys = xs[permutation], ys[permutation]
self.xs = xs
self.ys = ys
def get_iterator(self):
self.current_ind = 0
def _wrapper():
while self.current_ind < self.num_batch:
start_ind = self.batch_size * self.current_ind
end_ind = min(self.size, self.batch_size * (self.current_ind + 1))
x_i = self.xs[start_ind: end_ind, ...]
y_i = self.ys[start_ind: end_ind, ...]
yield (x_i, y_i)
self.current_ind += 1
return _wrapper()
class StandardScaler:
"""
Standardize the input
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std
def transform(self, data):
return (data - self.mean) / self.std
def inverse_transform(self, data):
return (data * self.std) + self.mean
def calculate_random_walk_matrix(adj_mx):
adj_mx = sp.coo_matrix(adj_mx)
d = np.array(adj_mx.sum(1))
d_inv = np.power(d, -1).flatten()
d_inv[np.isinf(d_inv)] = 0.
d_mat_inv = sp.diags(d_inv)
random_walk_mx = d_mat_inv.dot(adj_mx).tocoo()
return random_walk_mx
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
# Add file handler and stdout handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
# Create the log directory if necessary.
try:
os.makedirs(log_dir)
except OSError:
pass
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
file_handler.setFormatter(formatter)
file_handler.setLevel(level=level)
# Add console handler.
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(console_formatter)
console_handler.setLevel(level=level)
logging.basicConfig(handlers=[file_handler, console_handler], level=level)
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
logger = logging.getLogger(name)
logger.setLevel(level)
# Add file handler and stdout handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
file_handler.setFormatter(formatter)
# Add console handler.
console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(console_formatter)
logger.addHandler(file_handler)
logger.addHandler(console_handler)
# Add google cloud log handler
logger.info('Log directory: %s', log_dir)
return logger
def get_log_dir(kwargs):
log_dir = kwargs['train'].get('log_dir')
if log_dir is None:
batch_size = kwargs['data'].get('batch_size')
filter_type = kwargs['model'].get('filter_type')
gcn_step = kwargs['model'].get('gcn_step')
horizon = kwargs['model'].get('horizon')
latent_dim = kwargs['model'].get('latent_dim')
n_traj_samples = kwargs['model'].get('n_traj_samples')
ode_method = kwargs['model'].get('ode_method')
seq_len = kwargs['model'].get('seq_len')
rnn_units = kwargs['model'].get('rnn_units')
recg_type = kwargs['model'].get('recg_type')
if filter_type == 'unkP':
filter_type_abbr = 'UP'
elif filter_type == 'IncP':
filter_type_abbr = 'NV'
else:
filter_type_abbr = 'DF'
run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
recg_type, rnn_units, filter_type_abbr, gcn_step, latent_dim, n_traj_samples, ode_method, batch_size,
seq_len, horizon, time.strftime('%m%d%H%M%S'))
base_dir = kwargs.get('log_base_dir')
log_dir = os.path.join(base_dir, run_id)
if not os.path.exists(log_dir):
os.makedirs(log_dir)
return log_dir
def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
if ('BJ' in dataset_dir):
data = dict(np.load(os.path.join(dataset_dir, 'flow.npz'))) # convert readonly NpzFile to writable dict Object
for category in ['train', 'val', 'test']:
data['x_' + category] = data['x_' + category] # [..., :4] # ignore the time index
else:
data = {}
for category in ['train', 'val', 'test']:
cat_data = np.load(os.path.join(dataset_dir, category + '.npz'))
data['x_' + category] = cat_data['x']
data['y_' + category] = cat_data['y']
scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
# Data format
for category in ['train', 'val', 'test']:
data['x_' + category] = scaler.transform(data['x_' + category])
data['y_' + category] = scaler.transform(data['y_' + category])
data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
data['scaler'] = scaler
return data
def load_graph_data(pkl_filename):
adj_mx = np.load(pkl_filename)
return adj_mx
def graph_grad(adj_mx):
"""Fetch the graph gradient operator."""
num_nodes = adj_mx.shape[0]
num_edges = (adj_mx > 0.).sum()
grad = torch.zeros(num_nodes, num_edges)
e = 0
for i in range(num_nodes):
for j in range(num_nodes):
if adj_mx[i, j] == 0:
continue
grad[i, e] = 1.
grad[j, e] = -1.
e += 1
return grad
def init_network_weights(net, std=0.1):
"""
Just for nn.Linear net.
"""
for m in net.modules():
if isinstance(m, nn.Linear):
nn.init.normal_(m.weight, mean=0, std=std)
nn.init.constant_(m.bias, val=0)
def split_last_dim(data):
last_dim = data.size()[-1]
last_dim = last_dim // 2
res = data[..., :last_dim], data[..., last_dim:]
return res
def get_device(tensor):
device = torch.device("cpu")
if tensor.is_cuda:
device = tensor.get_device()
return device
def sample_standard_gaussian(mu, sigma):
device = get_device(mu)
d = torch.distributions.normal.Normal(torch.Tensor([0.]).to(device), torch.Tensor([1.]).to(device))
r = d.sample(mu.size()).squeeze(-1)
return r * sigma.float() + mu.float()
def create_net(n_inputs, n_outputs, n_layers=0, n_units=100, nonlinear=nn.Tanh):
layers = [nn.Linear(n_inputs, n_units)]
for i in range(n_layers):
layers.append(nonlinear())
layers.append(nn.Linear(n_units, n_units))
layers.append(nonlinear())
layers.append(nn.Linear(n_units, n_outputs))
return nn.Sequential(*layers)
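A toy illustration of graph_grad, the incidence-like operator the encoder and decoder rely on (run from the repo root): for each nonzero entry (i, j) of the adjacency matrix it creates one edge column with +1 at the source node and -1 at the target node.

import numpy as np
from models.STDEN.utils import graph_grad

adj = np.array([[0., 1.],
                [1., 0.]])
grad = graph_grad(adj)        # shape (num_nodes, num_edges) = (2, 2)
print(grad)
# tensor([[ 1., -1.],
#         [-1.,  1.]])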

180
models/STGODE/STGODE.py Executable file
View File

@@ -0,0 +1,180 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat
class Chomp1d(nn.Module):
"""
extra time steps are added by the padding; remove them
"""
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
"""
time dilation convolution
"""
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
"""
Args:
num_inputs : number of channels of the input features
num_channels : list of feature channel sizes per level; the last entry is the output channel count
kernel_size : using 1d convolution, so the real kernel is (1, kernel_size)
"""
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i - 1]
out_channels = num_channels[i]
padding = (kernel_size - 1) * dilation_size
self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
padding=(0, padding))
self.conv.weight.data.normal_(0, 0.01)
self.chomp = Chomp1d(padding)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
self.network = nn.Sequential(*layers)
self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
if self.downsample:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
"""
like ResNet
Args:
X : input data of shape (B, N, T, F)
"""
# permute shape to (B, F, N, T)
y = x.permute(0, 3, 1, 2)
y = F.relu(self.network(y) + self.downsample(y) if self.downsample else y)
y = y.permute(0, 2, 3, 1)
return y
class GCN(nn.Module):
def __init__(self, A_hat, in_channels, out_channels, ):
super(GCN, self).__init__()
self.A_hat = A_hat
self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
self.reset()
def reset(self):
stdv = 1. / math.sqrt(self.theta.shape[1])
self.theta.data.uniform_(-stdv, stdv)
def forward(self, X):
y = torch.einsum('ij, kjlm-> kilm', self.A_hat, X)
return F.relu(torch.einsum('kjlm, mn->kjln', y, self.theta))
class STGCNBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_nodes, A_hat):
"""
Args:
in_channels: Number of input features at each node in each time step.
out_channels: a list of feature channels in timeblock, the last is output feature channel
num_nodes: Number of nodes in the graph
A_hat: the normalized adjacency matrix
"""
super(STGCNBlock, self).__init__()
self.A_hat = A_hat
self.temporal1 = TemporalConvNet(num_inputs=in_channels,
num_channels=out_channels)
self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1],
num_channels=out_channels)
self.batch_norm = nn.BatchNorm2d(num_nodes)
def forward(self, X):
"""
Args:
X: Input data of shape (batch_size, num_nodes, num_timesteps, num_features)
Return:
Output data of shape(batch_size, num_nodes, num_timesteps, out_channels[-1])
"""
t = self.temporal1(X)
t = self.odeg(t)
t = self.temporal2(F.relu(t))
return self.batch_norm(t)
class ODEGCN(nn.Module):
""" the overall network framework """
def __init__(self, config):
"""
Args:
num_nodes : number of nodes in the graph
num_features : number of features at each node in each time step
num_timesteps_input : number of past time steps fed into the network
num_timesteps_output : desired number of future time steps output by the network
A_sp_hat : normalized spatial adjacency matrix
A_se_hat : normalized semantic adjacency matrix
"""
super(ODEGCN, self).__init__()
args = config['model']
num_nodes = config['data']['num_nodes']
num_features = args['num_features']
num_timesteps_input = args['history']
num_timesteps_output = args['horizon']
A_sp_hat, A_se_hat = get_A_hat(config)
# spatial graph
self.sp_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_sp_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
])
# semantic graph
self.se_blocks = nn.ModuleList([nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_se_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
])
self.pred = nn.Sequential(
nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
nn.ReLU(),
nn.Linear(num_timesteps_output * 32, num_timesteps_output)
)
def forward(self, x):
"""
Args:
x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
Returns:
prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
"""
x = x[..., 0:1].permute(0, 2, 1, 3)
outs = []
# spatial graph
for blk in self.sp_blocks:
outs.append(blk(x))
# semantic graph
for blk in self.se_blocks:
outs.append(blk(x))
outs = torch.stack(outs)
x = torch.max(outs, dim=0)[0]
x = x.reshape((x.shape[0], x.shape[1], -1))
return self.pred(x).permute(0,2,1).unsqueeze(dim=-1)
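A standalone sketch of the Chomp1d trick used by TemporalConvNet above: pad the time axis by (kernel_size - 1) * dilation, then trim the same number of trailing frames so the sequence length and causality are preserved. Toy tensors only:

import torch
import torch.nn as nn

kernel_size, dilation = 2, 1
padding = (kernel_size - 1) * dilation
conv = nn.Conv2d(1, 8, (1, kernel_size), dilation=(1, dilation), padding=(0, padding))

x = torch.randn(4, 1, 170, 12)                # (B, F, N, T), as inside TemporalConvNet
y = conv(x)                                   # (4, 8, 170, 13): one extra frame from the padding
y = y[:, :, :, :-padding].contiguous()        # Chomp1d: back to (4, 8, 170, 12)
print(y.shape)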

132
models/STGODE/adj.py Executable file
View File

@@ -0,0 +1,132 @@
import csv
import os
import pandas as pd
import numpy as np
from fastdtw import fastdtw
from tqdm import tqdm
import torch
files = {
358: ['PEMS03/PEMS03.npz', 'PEMS03/PEMS03.csv'],
307: ['PEMS04/PEMS04.npz', 'PEMS04/PEMS04.csv'],
883: ['PEMS07/PEMS07.npz', 'PEMS07/PEMS07.csv'],
170: ['PEMS08/PEMS08.npz', 'PEMS08/PEMS08.csv'],
# 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
# 'pemsD7M': ['PEMSD7M/PEMSD7M.npz', 'PEMSD7M/distance.csv'],
# 'pemsD7L': ['PEMSD7L/PEMSD7L.npz', 'PEMSD7L/distance.csv']
}
def get_A_hat(config):
"""read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
Args:
sigma1: float, default=0.1, sigma for the semantic matrix
sigma2: float, default=10, sigma for the spatial matrix
thres1: float, default=0.6, the threshold for the semantic matrix
thres2: float, default=0.5, the threshold for the spatial matrix
Returns:
dtw_matrix: tensor, normalized semantic adjacency matrix
sp_matrix: tensor, normalized spatial adjacency matrix
"""
file_path = config['data']['graph_pkl_filename']
filename = config['basic']['dataset']
dataset_path = config['data']['dataset_dir']
args = config['model']
data = np.load(file_path)
data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
num_node = data.shape[1]
mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
data = (data - mean_value) / std_value
# compute the DTW distance matrix; reuse the cached file if it already exists
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
axis=0)
data_mean = data_mean.squeeze().T
dtw_distance = np.zeros((num_node, num_node))
for i in tqdm(range(num_node)):
for j in range(i, num_node):
dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
for i in range(num_node):
for j in range(i):
dtw_distance[i][j] = dtw_distance[j][i]
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)
dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')
mean = np.mean(dist_matrix)
std = np.std(dist_matrix)
dist_matrix = (dist_matrix - mean) / std
sigma = args['sigma1']
dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
dtw_matrix = np.zeros_like(dist_matrix)
dtw_matrix[dist_matrix > args['thres1']] = 1
# compute the spatial distance matrix; reuse the cached file if it already exists
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
if num_node == 358:
with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}  # map sensor id -> matrix index
# read the edge list CSV with pandas, skipping the header row
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
for _, row in df.iterrows():
start = int(id_dict[row[0]])
end = int(id_dict[row[1]])
dist_matrix[start][end] = float(row[2])
dist_matrix[end][start] = float(row[2])
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
else:
# read the edge list CSV with pandas, skipping the header row
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
for _, row in df.iterrows():
start = int(row[0])
end = int(row[1])
dist_matrix[start][end] = float(row[2])
dist_matrix[end][start] = float(row[2])
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
# normalization
std = np.std(dist_matrix[dist_matrix != float('inf')])
mean = np.mean(dist_matrix[dist_matrix != float('inf')])
dist_matrix = (dist_matrix - mean) / std
sigma = args['sigma2']
sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
sp_matrix[sp_matrix < args['thres2']] = 0
# return order matches the caller's `A_sp_hat, A_se_hat = get_A_hat(config)`:
# the spatial matrix first, then the semantic (DTW) matrix
return (get_normalized_adj(sp_matrix).to(config['basic']['device']),
get_normalized_adj(dtw_matrix).to(config['basic']['device']))
def get_normalized_adj(A):
"""
Returns a tensor, the degree normalized adjacency matrix.
"""
alpha = 0.8
D = np.array(np.sum(A, axis=1)).reshape((-1,))
D[D <= 10e-5] = 10e-5 # Prevent infs
diag = np.reciprocal(np.sqrt(D))
A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A),
diag.reshape((1, -1)))
A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
return torch.from_numpy(A_reg.astype(np.float32))
if __name__ == '__main__':
    # Standalone pre-computation of the cached DTW / spatial distance matrices.
    # get_A_hat reads a nested experiment config, so one is assembled here from the
    # `files` table above; adjust the data/ prefix to the local dataset layout.
    hyper = {'sigma1': 0.1, 'sigma2': 10, 'thres1': 0.6, 'thres2': 0.5}
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    for npz_path, _csv_path in files.values():
        name = npz_path.split('/')[0]
        config = {
            'basic': {'dataset': name, 'device': device},
            'data': {'graph_pkl_filename': f'data/{npz_path}', 'dataset_dir': f'data/{name}'},
            'model': hyper,
        }
        get_A_hat(config)
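get_normalized_adj implements A_reg = (alpha / 2) * (I + D^{-1/2} A D^{-1/2}) with alpha = 0.8. Below is a small numeric check (not part of the diff), assuming the file is importable as models.STGODE.adj:

import numpy as np
import torch
from models.STGODE.adj import get_normalized_adj

# symmetric toy adjacency with degrees D = [2, 1, 1]
A = np.array([[0., 1., 1.],
              [1., 0., 0.],
              [1., 0., 0.]])
A_reg = get_normalized_adj(A)

D = A.sum(axis=1)
expected = 0.4 * (np.eye(3) + A / np.sqrt(np.outer(D, D)))  # alpha / 2 = 0.4
assert torch.allclose(A_reg, torch.tensor(expected, dtype=torch.float32), atol=1e-6)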

74
models/STGODE/odegcn.py Executable file
View File

@ -0,0 +1,74 @@
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
# Whether to use the adjoint method for backpropagation through the ODE solver.
adjoint = False
if adjoint:
from torchdiffeq import odeint_adjoint as odeint
else:
from torchdiffeq import odeint
# Define the ODE function.
# Input:
# --- t: A tensor with shape [], meaning the current time.
# --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
# Output:
# --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
class ODEFunc(nn.Module):
def __init__(self, feature_dim, temporal_dim, adj):
super(ODEFunc, self).__init__()
self.adj = adj
self.x0 = None
self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1]))
self.beta = 0.6
self.w = nn.Parameter(torch.eye(feature_dim))
self.d = nn.Parameter(torch.zeros(feature_dim) + 1)
self.w2 = nn.Parameter(torch.eye(temporal_dim))
self.d2 = nn.Parameter(torch.zeros(temporal_dim) + 1)
def forward(self, t, x):
alpha = torch.sigmoid(self.alpha).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
xa = torch.einsum('ij, kjlm->kilm', self.adj, x)
# ensure the eigenvalues are less than 1
d = torch.clamp(self.d, min=0, max=1)
w = torch.mm(self.w * d, torch.t(self.w))
xw = torch.einsum('ijkl, lm->ijkm', x, w)
d2 = torch.clamp(self.d2, min=0, max=1)
w2 = torch.mm(self.w2 * d2, torch.t(self.w2))
xw2 = torch.einsum('ijkl, km->ijml', x, w2)
f = alpha / 2 * xa - x + xw - x + xw2 - x + self.x0
return f
class ODEblock(nn.Module):
def __init__(self, odefunc, t=torch.tensor([0,1])):
super(ODEblock, self).__init__()
self.t = t
self.odefunc = odefunc
def set_x0(self, x0):
self.odefunc.x0 = x0.clone().detach()
def forward(self, x):
t = self.t.type_as(x)
z = odeint(self.odefunc, x, t, method='euler')[1]
return z
# Define the ODEGCN model.
class ODEG(nn.Module):
def __init__(self, feature_dim, temporal_dim, adj, time):
super(ODEG, self).__init__()
self.odeblock = ODEblock(ODEFunc(feature_dim, temporal_dim, adj), t=torch.tensor([0, time]))
def forward(self, x):
self.odeblock.set_x0(x)
z = self.odeblock(x)
return F.relu(z)
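A minimal forward-pass sketch for ODEG (not part of the diff): it assumes torchdiffeq is installed and mirrors how STGCNBlock calls it on a (batch, nodes, time, channels) tensor, with an identity matrix standing in for the normalized adjacency.

import torch
from models.STGODE.odegcn import ODEG

B, N, T, C = 4, 10, 12, 64
adj = torch.eye(N)                         # stand-in for a normalized adjacency
odeg = ODEG(feature_dim=C, temporal_dim=T, adj=adj, time=6)

x = torch.randn(B, N, T, C)
out = odeg(x)                              # Euler integration of the ODE, then ReLU
print(out.shape)                           # torch.Size([4, 10, 12, 64])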

View File

@ -1,6 +1,12 @@
from models.STDEN.stden_model import STDENModel
from models.STGODE.STGODE import ODEGCN
def model_selector(config):
model_name = config['basic']['model']
model = None
match model_name:
case 'STDEN':
model = STDENModel(config)
case 'STGODE':
model = ODEGCN(config)
return model
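For reference, a hedged sketch (not part of the diff) of the config layout the STGODE branch expects. The key names are taken from ODEGCN.__init__ and get_A_hat above; the dataset paths and values are illustrative only, and the selector module path is assumed to be models/model_selector.py.

from models.model_selector import model_selector

config = {
    'basic': {'model': 'STGODE', 'dataset': 'PEMS08', 'device': 'cpu'},
    'data': {'num_nodes': 170,
             'dataset_dir': 'data/PEMS08',                      # illustrative path
             'graph_pkl_filename': 'data/PEMS08/PEMS08.npz'},   # illustrative path
    'model': {'num_features': 1, 'history': 12, 'horizon': 12,
              'sigma1': 0.1, 'sigma2': 10, 'thres1': 0.6, 'thres2': 0.5},
}
model = model_selector(config)  # builds ODEGCN; get_A_hat then reads the data files above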

Binary file not shown.

Binary file not shown.

BIN
test_semantic.npy Normal file

Binary file not shown.

BIN
test_spatial.npy Normal file

Binary file not shown.

576
trainer/ode_trainer.py Normal file
View File

@ -0,0 +1,576 @@
import math
import os
import time
import copy
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
class Trainer:
def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
scalers, logger, lr_scheduler=None):
self.model = model
self.loss = loss
self.optimizer = optimizer
self.train_loader = train_loader
self.val_loader = val_loader
self.test_loader = test_loader
self.scalers = scalers  # a list of per-feature scalers
self.args = config['train']
self.logger = logger
self.args['device'] = config['basic']['device']
self.lr_scheduler = lr_scheduler
self.train_per_epoch = len(train_loader)
self.val_per_epoch = len(val_loader) if val_loader else 0
self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
# buffers for collecting NFE (number of function evaluations) statistics
self.c = []
self.res, self.keys = [], []
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
self.model.train()
optimizer_step = True
else:
self.model.eval()
optimizer_step = False
total_loss = 0
epoch_time = time.time()
# reset the NFE statistics for this epoch
if mode == 'train':
self.c.clear()
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
label = target[..., :self.args['output_dim']]
output, fe = self.model(data)
if self.args['real_value']:
# inverse-transform only the output dimensions
output = self._inverse_transform_output(output)
loss = self.loss(output, label)
# collect NFE statistics (training mode only)
if mode == 'train':
self.c.append([*fe, loss.item()])
# log the function-evaluation statistics
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad()
loss.backward()
if self.args['grad_norm']:
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
self.optimizer.step()
total_loss += loss.item()
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
self.logger.info(
f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
# advance the tqdm progress bar
pbar.update(1)
pbar.set_postfix(loss=loss.item())
avg_loss = total_loss / len(dataloader)
self.logger.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# store this epoch's NFE statistics (training mode only)
if mode == 'train':
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
self.keys.append(epoch)
return avg_loss
def _inverse_transform_output(self, output):
"""
只对输出维度进行反归一化
假设输出数据形状为 [batch, horizon, nodes, features]
只对前output_dim个特征进行反归一化
"""
if not self.args['real_value']:
return output
# number of output features
output_dim = self.args['output_dim']
# if there are at least as many scalers as output features, use the matching scaler for each feature
if output_dim <= len(self.scalers):
# inverse-transform each output feature separately
for feature_idx in range(output_dim):
if feature_idx < len(self.scalers):
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
output[..., feature_idx:feature_idx+1]
)
else:
# if there are more output features than scalers, only the first len(scalers) features are inverse-transformed
for feature_idx in range(len(self.scalers)):
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
output[..., feature_idx:feature_idx+1]
)
return output
def train_epoch(self, epoch):
return self._run_epoch(epoch, self.train_loader, 'train')
def val_epoch(self, epoch):
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
def test_epoch(self, epoch):
return self._run_epoch(epoch, self.test_loader, 'test')
def train(self):
best_model, best_test_model = None, None
best_loss, best_test_loss = float('inf'), float('inf')
not_improved_count = 0
self.logger.logger.info("Training process started")
for epoch in range(1, self.args['epochs'] + 1):
train_epoch_loss = self.train_epoch(epoch)
val_epoch_loss = self.val_epoch(epoch)
test_epoch_loss = self.test_epoch(epoch)
if train_epoch_loss > 1e6:
self.logger.logger.warning('Gradient explosion detected. Ending...')
break
if val_epoch_loss < best_loss:
best_loss = val_epoch_loss
not_improved_count = 0
best_model = copy.deepcopy(self.model.state_dict())
torch.save(best_model, self.best_path)
self.logger.logger.info('Best validation model saved!')
else:
not_improved_count += 1
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
self.logger.logger.info(
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
break
if test_epoch_loss < best_test_loss:
best_test_loss = test_epoch_loss
best_test_model = copy.deepcopy(self.model.state_dict())
torch.save(best_test_model, self.best_test_path)
# save NFE statistics if enabled in the config
# (self.args is a plain dict, so check the key directly rather than with hasattr)
if self.args.get('nfe', False):
self._save_nfe_data()
if not self.args['debug']:
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
self._finalize_training(best_model, best_test_model)
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
self.logger.logger.info("Testing on best validation model")
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
self.model.load_state_dict(best_test_model)
self.logger.logger.info("Testing on best test model")
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
@staticmethod
def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
model.to(args['device'])
model.eval()
y_pred, y_true = [], []
# buffer for collecting NFE statistics
c = []
with torch.no_grad():
for data, target in data_loader:
label = target[..., :args['output_dim']]
output, fe = model(data)
y_pred.append(output)
y_true.append(label)
# collect NFE statistics
c.append([*fe, 0.0])  # no loss during testing, record 0
if args['real_value']:
# inverse-transform only the output dimensions
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
# metrics for each prediction horizon step
for t in range(y_true.shape[1]):
mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
args['mae_thresh'], args['mape_thresh'])
logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
# save NFE statistics if enabled in the config
# (args is a plain dict, so check the key directly rather than with hasattr)
if args.get('nfe', False):
Trainer._save_nfe_data_static(c, model, logger)
# only generate visualizations when requested
if generate_viz:
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
target_node=1, num_samples=10, scalers=scalers)
@staticmethod
def _inverse_transform_output_static(output, args, scalers):
"""
静态方法只对输出维度进行反归一化
"""
if not args['real_value']:
return output
# number of output features
output_dim = args['output_dim']
# if there are at least as many scalers as output features, use the matching scaler for each feature
if output_dim <= len(scalers):
# inverse-transform each output feature separately
for feature_idx in range(output_dim):
if feature_idx < len(scalers):
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
output[..., feature_idx:feature_idx+1]
)
else:
# if there are more output features than scalers, only the first len(scalers) features are inverse-transformed
for feature_idx in range(len(scalers)):
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
output[..., feature_idx:feature_idx+1]
)
return output
@staticmethod
def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
"""
生成节点预测可视化图片
Args:
y_pred: 预测值
y_true: 真实值
logger: 日志记录器
save_dir: 保存目录
"""
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib
from tqdm import tqdm
# 设置matplotlib配置减少字体查找输出
matplotlib.set_loglevel('error') # 只显示错误信息
plt.rcParams['font.family'] = 'DejaVu Sans' # 使用默认字体
# basic validity check
if y_pred is None or y_true is None:
return
# create the pic directory
pic_dir = os.path.join(save_dir, 'pic')
os.makedirs(pic_dir, exist_ok=True)
# always generate plots for 10 nodes
num_nodes_to_plot = 10
# detailed plot for each individual node
with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
for node_id in range(num_nodes_to_plot):
# extract this node's data
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
# data layout: [time_step, seq_len, num_node, dim]
node_pred = y_pred[:, 12, node_id, 0].cpu().numpy()  # horizon index 12, selected node, first feature
node_true = y_true[:, 12, node_id, 0].cpu().numpy()
else:
# if there are fewer than 10 nodes, only handle the ones that exist
if node_id >= y_pred.shape[-2]:
pbar.update(1)
continue
else:
node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
node_true = y_true[:, 0, node_id, 0].cpu().numpy()
# skip nodes with invalid (NaN) data
if np.isnan(node_pred).any() or np.isnan(node_true).any():
pbar.update(1)
continue
# take the first 500 time steps
max_steps = min(500, len(node_pred))
if max_steps <= 0:
pbar.update(1)
continue
node_pred_500 = node_pred[:max_steps]
node_true_500 = node_true[:max_steps]
# build the time axis
time_steps = np.arange(max_steps)
# plot the comparison
plt.figure(figsize=(12, 6))
plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
plt.legend()
plt.grid(True, alpha=0.3)
# save the figure under a per-node file name
save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
pbar.update(1)
# summary comparison across all nodes (first 100 time steps for readability)
# use the first 100 time steps
plot_steps = min(100, y_pred.shape[0])
if plot_steps <= 0:
return
# create the subplot grid
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()
for node_id in range(num_nodes_to_plot):
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
# data layout: [time_step, seq_len, num_node, dim]
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
else:
# if there are fewer than 10 nodes, only handle the ones that exist
if node_id >= y_pred.shape[-2]:
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
ha='center', va='center', transform=axes[node_id].transAxes)
continue
else:
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
# skip nodes with invalid (NaN) data
if np.isnan(node_pred).any() or np.isnan(node_true).any():
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
ha='center', va='center', transform=axes[node_id].transAxes)
continue
time_steps = np.arange(plot_steps)
axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
axes[node_id].set_title(f'Node {node_id + 1}')
axes[node_id].grid(True, alpha=0.3)
axes[node_id].legend(fontsize=8)
if node_id >= 5:  # bottom row gets x-axis labels
axes[node_id].set_xlabel('Time Steps')
if node_id % 5 == 0:  # left column gets y-axis labels
axes[node_id].set_ylabel('Values')
plt.tight_layout()
summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
plt.close()
@staticmethod
def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
target_node=1, num_samples=10, scalers=None):
"""
生成输入-输出样本比较图
Args:
y_pred: 预测值
y_true: 真实值
data_loader: 数据加载器用于获取输入数据
logger: 日志记录器
save_dir: 保存目录
target_node: 目标节点ID从1开始
num_samples: 要比较的样本数量
scalers: 标准化器列表用于反归一化输入数据
"""
import matplotlib.pyplot as plt
import numpy as np
import os
import matplotlib
from tqdm import tqdm
# configure matplotlib
matplotlib.set_loglevel('error')
plt.rcParams['font.family'] = 'DejaVu Sans'
# create the compare directory
compare_dir = os.path.join(save_dir, 'pic', 'compare')
os.makedirs(compare_dir, exist_ok=True)
# collect input batches
input_data = []
for batch_idx, (data, target) in enumerate(data_loader):
if batch_idx >= num_samples:
break
input_data.append(data.cpu().numpy())
if not input_data:
return
# zero-based index of the target node
node_idx = target_node - 1
# make sure the node index is valid
if node_idx >= y_pred.shape[-2]:
return
# generate a comparison plot for each sample
with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
for sample_idx in range(min(num_samples, len(input_data))):
# input sequence (assumes the input shape is [batch, seq_len, nodes, features])
input_seq = input_data[sample_idx][0, :, node_idx, 0]  # first item in the batch, all time steps, target node, first feature
# inverse-transform the input data
if scalers is not None and len(scalers) > 0:
# use the first scaler (assumes the input feature was normalized with it)
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
# corresponding predictions and ground truth
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()  # all horizon steps, target node, first feature
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
# skip samples with invalid (NaN) data
if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
pbar.update(1)
continue
# build a time axis covering the input followed by the output
total_time = np.arange(len(input_seq) + len(pred_seq))
# one combined figure with input and output on the same axes
plt.figure(figsize=(14, 8))
# plot the full ground-truth curve (input + true output)
true_combined = np.concatenate([input_seq, true_seq])
plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
linewidth=2.5, alpha=0.9, linestyle='-')
# plot the prediction curve (output segment only)
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
linewidth=2, alpha=0.8, linestyle='-')
# vertical line separating input from output
plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
label='Input/Output Boundary')
# figure labels and styling
plt.xlabel('Time Steps')
plt.ylabel('Values')
plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
plt.legend()
plt.grid(True, alpha=0.3)
# adjust the layout
plt.tight_layout()
# save the figure
save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
plt.savefig(save_path, dpi=300, bbox_inches='tight')
plt.close()
pbar.update(1)
# summary figure comparing predictions across all samples
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
axes = axes.flatten()
for sample_idx in range(min(num_samples, len(input_data))):
if sample_idx >= 10:  # show at most 10 subplots
break
ax = axes[sample_idx]
# input sequence, predictions, and ground truth
input_seq = input_data[sample_idx][0, :, node_idx, 0]
if scalers is not None and len(scalers) > 0:
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
# skip samples with invalid (NaN) data
if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
ha='center', va='center', transform=ax.transAxes)
continue
# plot input and output as one continuous comparison
total_time = np.arange(len(input_seq) + len(pred_seq))
true_combined = np.concatenate([input_seq, true_seq])
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
ax.set_title(f'Sample {sample_idx + 1}')
ax.grid(True, alpha=0.3)
ax.legend(fontsize=8)
if sample_idx >= 5:  # bottom row gets x-axis labels
ax.set_xlabel('Time Steps')
if sample_idx % 5 == 0:  # left column gets y-axis labels
ax.set_ylabel('Values')
# hide unused subplots
for i in range(min(num_samples, len(input_data)), 10):
axes[i].set_visible(False)
plt.tight_layout()
summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
plt.close()
def _save_nfe_data(self):
"""保存nfe数据到文件"""
if not self.res:
return
res = pd.concat(self.res, keys=self.keys)
res.index.names = ['epoch', 'iter']
# model configuration parameters used in the file name
filter_type = getattr(self.model, 'filter_type', 'unknown')
atol = getattr(self.model, 'atol', 1e-5)
rtol = getattr(self.model, 'rtol', 1e-5)
# write the NFE data
nfe_file = os.path.join(
self.logger.dir_path,
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
res.to_pickle(nfe_file)
self.logger.logger.info(f"NFE data saved to {nfe_file}")
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
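_compute_sampling_threshold is the usual inverse-sigmoid decay used for scheduled sampling: the probability of feeding ground truth starts near 1 and crosses 0.5 around global_step = k * ln(k). A quick standalone evaluation (not part of the diff) of the same formula:

import math

def compute_sampling_threshold(global_step, k):
    return k / (k + math.exp(global_step / k))

for step in (0, 10000, 15000, 20000, 30000):
    print(step, round(compute_sampling_threshold(step, k=2000), 4))
# approx: 0.9995, 0.9309, 0.5252, 0.0832, 0.0006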

View File

@ -1,11 +1,14 @@
from trainer.trainer import Trainer
from trainer.ode_trainer import Trainer as ode_trainer
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
lr_scheduler, kwargs):
model_name = config['basic']['model']
selected_Trainer = None
match model_name:
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler, lr_scheduler)
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler,lr_scheduler)
train_loader, val_loader, test_loader, scaler, lr_scheduler)
if selected_Trainer is None: raise NotImplementedError
return selected_Trainer