Compare commits
7 Commits
19a02ba7ae
...
f0d3460c89
| Author | SHA1 | Date |
|---|---|---|
|
|
f0d3460c89 | |
|
|
387f64efab | |
|
|
56a7f309fa | |
|
|
a8cc3a20fd | |
|
|
ab5811425d | |
|
|
66a23ffbbb | |
|
|
df8c573f4c |
|
|
@ -160,3 +160,8 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
.STDEN/
|
||||||
|
.data/PEMS08/
|
||||||
|
exp/
|
||||||
|
STDEN/
|
||||||
|
models/gpt2/
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<Languages>
|
||||||
|
<language minSize="136" name="Python" />
|
||||||
|
</Languages>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="8">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,216 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="AutoImportSettings">
|
||||||
|
<option name="autoReloadType" value="SELECTIVE" />
|
||||||
|
</component>
|
||||||
|
<component name="ChangeListManager">
|
||||||
|
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
|
||||||
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
|
||||||
|
</list>
|
||||||
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||||
|
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||||
|
</component>
|
||||||
|
<component name="FileTemplateManagerImpl">
|
||||||
|
<option name="RECENT_TEMPLATES">
|
||||||
|
<list>
|
||||||
|
<option value="Python Script" />
|
||||||
|
</list>
|
||||||
|
</option>
|
||||||
|
</component>
|
||||||
|
<component name="Git.Settings">
|
||||||
|
<excluded-from-favorite>
|
||||||
|
<branch-storage>
|
||||||
|
<map>
|
||||||
|
<entry type="LOCAL">
|
||||||
|
<value>
|
||||||
|
<list>
|
||||||
|
<branch-info repo="$PROJECT_DIR$" source="main" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</entry>
|
||||||
|
</map>
|
||||||
|
</branch-storage>
|
||||||
|
</excluded-from-favorite>
|
||||||
|
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
||||||
|
<map>
|
||||||
|
<entry key="$PROJECT_DIR$" value="main" />
|
||||||
|
</map>
|
||||||
|
</option>
|
||||||
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
|
<option name="ROOT_SYNC" value="DONT_SYNC" />
|
||||||
|
</component>
|
||||||
|
<component name="MarkdownSettingsMigration">
|
||||||
|
<option name="stateVersion" value="1" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectColorInfo">{
|
||||||
|
"associatedIndex": 3
|
||||||
|
}</component>
|
||||||
|
<component name="ProjectId" id="3264JlB7seHXuXCCcdmTyEsXI45" />
|
||||||
|
<component name="ProjectViewState">
|
||||||
|
<option name="hideEmptyMiddlePackages" value="true" />
|
||||||
|
<option name="showLibraryContents" value="true" />
|
||||||
|
</component>
|
||||||
|
<component name="PropertiesComponent"><![CDATA[{
|
||||||
|
"keyToString": {
|
||||||
|
"Python.STDEN.executor": "Debug",
|
||||||
|
"Python.STGODE.executor": "Run",
|
||||||
|
"Python.main.executor": "Run",
|
||||||
|
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||||
|
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||||
|
"git-widget-placeholder": "STGODE",
|
||||||
|
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
|
||||||
|
"node.js.detected.package.eslint": "true",
|
||||||
|
"node.js.detected.package.tslint": "true",
|
||||||
|
"node.js.selected.package.eslint": "(autodetect)",
|
||||||
|
"node.js.selected.package.tslint": "(autodetect)",
|
||||||
|
"nodejs_package_manager_path": "npm",
|
||||||
|
"vue.rearranger.settings.migration": "true"
|
||||||
|
}
|
||||||
|
}]]></component>
|
||||||
|
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
|
||||||
|
<layout>
|
||||||
|
<window_info id="Space Code Reviews" show_stripe_button="false" />
|
||||||
|
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
|
||||||
|
<window_info id="Merge Requests" show_stripe_button="false" />
|
||||||
|
<window_info id="Commit_Guest" show_stripe_button="false" />
|
||||||
|
<window_info id="Pull Requests" show_stripe_button="false" />
|
||||||
|
<window_info id="Learn" show_stripe_button="false" />
|
||||||
|
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
|
||||||
|
<window_info id="Commit" order="1" weight="0.25" />
|
||||||
|
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
||||||
|
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
|
||||||
|
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
|
||||||
|
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
|
||||||
|
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
|
||||||
|
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
|
||||||
|
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
|
||||||
|
<window_info anchor="bottom" id="Version Control" order="0" />
|
||||||
|
<window_info anchor="bottom" id="Problems" order="1" />
|
||||||
|
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
|
||||||
|
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
|
||||||
|
<window_info anchor="bottom" id="Services" order="4" />
|
||||||
|
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
||||||
|
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
|
||||||
|
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
|
||||||
|
<window_info anchor="right" id="SciView" show_stripe_button="false" />
|
||||||
|
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
|
||||||
|
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
|
||||||
|
<window_info anchor="right" id="Database" order="2" weight="0.25" />
|
||||||
|
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
|
||||||
|
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
|
||||||
|
<window_info anchor="right" id="Plots" order="5" weight="0.1" />
|
||||||
|
</layout>
|
||||||
|
</component>
|
||||||
|
<component name="RecentsManager">
|
||||||
|
<key name="CopyFile.RECENT_KEYS">
|
||||||
|
<recent name="$PROJECT_DIR$/trainer" />
|
||||||
|
<recent name="$PROJECT_DIR$/configs/STDEN" />
|
||||||
|
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||||
|
</key>
|
||||||
|
<key name="MoveFile.RECENT_KEYS">
|
||||||
|
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||||
|
</key>
|
||||||
|
</component>
|
||||||
|
<component name="RunManager" selected="Python.STGODE">
|
||||||
|
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
|
||||||
|
<module name="Project-I" />
|
||||||
|
<option name="ENV_FILES" value="" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="SDK_NAME" value="TS" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="" />
|
||||||
|
<option name="IS_MODULE_SDK" value="false" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||||
|
<option name="PARAMETERS" value="--config ./configs/STDEN/PEMS08.yaml" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
|
||||||
|
<module name="Project-I" />
|
||||||
|
<option name="ENV_FILES" value="" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="SDK_NAME" value="TS" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="" />
|
||||||
|
<option name="IS_MODULE_SDK" value="false" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||||
|
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<list>
|
||||||
|
<item itemvalue="Python.STDEN" />
|
||||||
|
<item itemvalue="Python.STGODE" />
|
||||||
|
</list>
|
||||||
|
</component>
|
||||||
|
<component name="SharedIndexes">
|
||||||
|
<attachedChunks>
|
||||||
|
<set>
|
||||||
|
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
|
||||||
|
</set>
|
||||||
|
</attachedChunks>
|
||||||
|
</component>
|
||||||
|
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
|
||||||
|
<component name="TaskManager">
|
||||||
|
<task active="true" id="Default" summary="默认任务">
|
||||||
|
<changelist id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="" />
|
||||||
|
<created>1756727620810</created>
|
||||||
|
<option name="number" value="Default" />
|
||||||
|
<option name="presentableId" value="Default" />
|
||||||
|
<updated>1756727620810</updated>
|
||||||
|
<workItem from="1756727623101" duration="4721000" />
|
||||||
|
<workItem from="1756856673845" duration="652000" />
|
||||||
|
<workItem from="1756864144998" duration="1063000" />
|
||||||
|
</task>
|
||||||
|
<servers />
|
||||||
|
</component>
|
||||||
|
<component name="TypeScriptGeneratedFilesManager">
|
||||||
|
<option name="version" value="3" />
|
||||||
|
</component>
|
||||||
|
<component name="XDebuggerManager">
|
||||||
|
<breakpoint-manager>
|
||||||
|
<breakpoints>
|
||||||
|
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||||
|
<url>file://$PROJECT_DIR$/models/STDEN/stden_model.py</url>
|
||||||
|
<line>131</line>
|
||||||
|
<option name="timeStamp" value="5" />
|
||||||
|
</line-breakpoint>
|
||||||
|
</breakpoints>
|
||||||
|
</breakpoint-manager>
|
||||||
|
</component>
|
||||||
|
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||||
|
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||||
|
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2
|
||||||
|
|
@ -0,0 +1,65 @@
|
||||||
|
basic:
|
||||||
|
device: cuda:0
|
||||||
|
dataset: PEMS08
|
||||||
|
model: STDEN
|
||||||
|
mode: train
|
||||||
|
seed: 2025
|
||||||
|
|
||||||
|
data:
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
num_nodes: 170
|
||||||
|
batch_size: 32
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
horizon: 24
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 24
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
l1_decay: 0
|
||||||
|
seq_len: 12
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 0.00001
|
||||||
|
odeint_rtol: 0.00001
|
||||||
|
rnn_units: 64
|
||||||
|
num_rnn_layers: 1
|
||||||
|
gcn_step: 2
|
||||||
|
filter_type: default # unkP IncP default
|
||||||
|
recg_type: gru
|
||||||
|
save_latent: false
|
||||||
|
nfe: false
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss: mae
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 100
|
||||||
|
lr_init: 0.003
|
||||||
|
mape_thresh: 0.001
|
||||||
|
mae_thresh: None
|
||||||
|
debug: False
|
||||||
|
output_dim: 1
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
---
|
||||||
|
log_base_dir: logs/BJ_GM
|
||||||
|
log_level: INFO
|
||||||
|
|
||||||
|
data:
|
||||||
|
batch_size: 32
|
||||||
|
dataset_dir: data/BJ_GM
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/sensor_graph/adj_GM.npy
|
||||||
|
|
||||||
|
model:
|
||||||
|
l1_decay: 0
|
||||||
|
seq_len: 12
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 0.00001
|
||||||
|
odeint_rtol: 0.00001
|
||||||
|
rnn_units: 64
|
||||||
|
num_rnn_layers: 1
|
||||||
|
gcn_step: 2
|
||||||
|
filter_type: default # unkP IncP default
|
||||||
|
recg_type: gru
|
||||||
|
save_latent: false
|
||||||
|
nfe: false
|
||||||
|
|
||||||
|
train:
|
||||||
|
base_lr: 0.01
|
||||||
|
dropout: 0
|
||||||
|
load: 0
|
||||||
|
epoch: 0
|
||||||
|
epochs: 100
|
||||||
|
epsilon: 1.0e-3
|
||||||
|
lr_decay_ratio: 0.1
|
||||||
|
max_grad_norm: 5
|
||||||
|
min_learning_rate: 2.0e-06
|
||||||
|
optimizer: adam
|
||||||
|
patience: 20
|
||||||
|
steps: [20, 30, 40, 50]
|
||||||
|
test_every_n_epochs: 5
|
||||||
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
---
|
||||||
|
log_base_dir: logs/BJ_RM
|
||||||
|
log_level: INFO
|
||||||
|
|
||||||
|
data:
|
||||||
|
batch_size: 32
|
||||||
|
dataset_dir: data/BJ_RM
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/sensor_graph/adj_RM.npy
|
||||||
|
|
||||||
|
model:
|
||||||
|
l1_decay: 0
|
||||||
|
seq_len: 12
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 0.00001
|
||||||
|
odeint_rtol: 0.00001
|
||||||
|
rnn_units: 64 # for recognition
|
||||||
|
num_rnn_layers: 1
|
||||||
|
gcn_step: 2
|
||||||
|
filter_type: default # unkP IncP default
|
||||||
|
recg_type: gru
|
||||||
|
save_latent: false
|
||||||
|
nfe: false
|
||||||
|
|
||||||
|
train:
|
||||||
|
base_lr: 0.01
|
||||||
|
dropout: 0
|
||||||
|
load: 0 # 0 for not load
|
||||||
|
epoch: 0
|
||||||
|
epochs: 100
|
||||||
|
epsilon: 1.0e-3
|
||||||
|
lr_decay_ratio: 0.1
|
||||||
|
max_grad_norm: 5
|
||||||
|
min_learning_rate: 2.0e-06
|
||||||
|
optimizer: adam
|
||||||
|
patience: 20
|
||||||
|
steps: [20, 30, 40, 50]
|
||||||
|
test_every_n_epochs: 5
|
||||||
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
---
|
||||||
|
log_base_dir: logs/BJ_XZ
|
||||||
|
log_level: INFO
|
||||||
|
|
||||||
|
data:
|
||||||
|
batch_size: 32
|
||||||
|
dataset_dir: data/BJ_XZ
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/sensor_graph/adj_XZ.npy
|
||||||
|
|
||||||
|
model:
|
||||||
|
l1_decay: 0
|
||||||
|
seq_len: 12
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 0.00001
|
||||||
|
odeint_rtol: 0.00001
|
||||||
|
rnn_units: 64
|
||||||
|
num_rnn_layers: 1
|
||||||
|
gcn_step: 2
|
||||||
|
filter_type: default # unkP IncP default
|
||||||
|
recg_type: gru
|
||||||
|
save_latent: false
|
||||||
|
nfe: false
|
||||||
|
|
||||||
|
train:
|
||||||
|
base_lr: 0.01
|
||||||
|
dropout: 0
|
||||||
|
load: 0 # 0 for not load
|
||||||
|
epoch: 0
|
||||||
|
epochs: 100
|
||||||
|
epsilon: 1.0e-3
|
||||||
|
lr_decay_ratio: 0.1
|
||||||
|
max_grad_norm: 5
|
||||||
|
min_learning_rate: 2.0e-06
|
||||||
|
optimizer: adam
|
||||||
|
patience: 20
|
||||||
|
steps: [20, 30, 40, 50]
|
||||||
|
test_every_n_epochs: 5
|
||||||
|
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
basic:
|
||||||
|
device: cuda:0
|
||||||
|
dataset: PEMS08
|
||||||
|
model: STGODE
|
||||||
|
mode: train
|
||||||
|
seed: 2025
|
||||||
|
|
||||||
|
data:
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
num_nodes: 170
|
||||||
|
batch_size: 64
|
||||||
|
input_dim: 1
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 24
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
num_features: 1
|
||||||
|
rnn_units: 64
|
||||||
|
sigma1: 0.1
|
||||||
|
sigma2: 10
|
||||||
|
thres1: 0.6
|
||||||
|
thres2: 0.5
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss: mae
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 100
|
||||||
|
lr_init: 0.003
|
||||||
|
mape_thresh: 0.001
|
||||||
|
mae_thresh: None
|
||||||
|
debug: False
|
||||||
|
output_dim: 1
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
log_step: 3000
|
||||||
|
|
||||||
|
|
@ -0,0 +1,296 @@
|
||||||
|
from,to,cost
|
||||||
|
9,153,310.6
|
||||||
|
153,62,330.9
|
||||||
|
62,111,332.9
|
||||||
|
111,11,324.2
|
||||||
|
11,28,336.0
|
||||||
|
28,169,133.7
|
||||||
|
138,135,354.7
|
||||||
|
135,133,387.9
|
||||||
|
133,163,337.1
|
||||||
|
163,20,352.0
|
||||||
|
20,19,420.8
|
||||||
|
19,14,351.3
|
||||||
|
14,39,340.2
|
||||||
|
39,164,350.3
|
||||||
|
164,167,365.2
|
||||||
|
167,70,359.0
|
||||||
|
70,59,388.2
|
||||||
|
59,58,305.7
|
||||||
|
58,67,294.4
|
||||||
|
67,66,299.5
|
||||||
|
66,55,313.3
|
||||||
|
55,53,332.1
|
||||||
|
53,150,278.9
|
||||||
|
150,61,308.4
|
||||||
|
61,64,311.4
|
||||||
|
64,63,243.6
|
||||||
|
47,65,372.8
|
||||||
|
65,48,319.4
|
||||||
|
48,49,309.7
|
||||||
|
49,54,320.5
|
||||||
|
54,56,318.3
|
||||||
|
56,57,297.9
|
||||||
|
57,68,293.5
|
||||||
|
68,69,342.5
|
||||||
|
69,60,318.0
|
||||||
|
60,17,305.9
|
||||||
|
17,5,321.4
|
||||||
|
5,18,402.2
|
||||||
|
18,22,447.4
|
||||||
|
22,30,377.5
|
||||||
|
30,29,417.7
|
||||||
|
29,21,360.8
|
||||||
|
21,132,407.6
|
||||||
|
132,134,386.9
|
||||||
|
134,136,350.2
|
||||||
|
123,121,326.3
|
||||||
|
121,140,385.2
|
||||||
|
140,118,393.0
|
||||||
|
118,96,296.7
|
||||||
|
96,94,398.2
|
||||||
|
94,86,337.1
|
||||||
|
86,78,473.8
|
||||||
|
78,46,353.4
|
||||||
|
46,152,385.7
|
||||||
|
152,157,350.0
|
||||||
|
157,35,354.4
|
||||||
|
35,77,356.1
|
||||||
|
77,52,354.2
|
||||||
|
52,3,357.8
|
||||||
|
3,16,382.4
|
||||||
|
16,0,55.7
|
||||||
|
42,12,335.1
|
||||||
|
12,139,328.8
|
||||||
|
139,168,412.6
|
||||||
|
168,154,337.3
|
||||||
|
154,143,370.7
|
||||||
|
143,10,6.3
|
||||||
|
107,105,354.6
|
||||||
|
105,104,386.9
|
||||||
|
104,148,362.1
|
||||||
|
148,97,316.3
|
||||||
|
97,101,380.7
|
||||||
|
101,137,361.4
|
||||||
|
137,102,365.5
|
||||||
|
102,24,375.5
|
||||||
|
24,166,312.2
|
||||||
|
129,156,256.1
|
||||||
|
156,33,329.1
|
||||||
|
33,32,356.5
|
||||||
|
91,89,405.6
|
||||||
|
89,147,347.0
|
||||||
|
147,15,351.7
|
||||||
|
15,44,339.5
|
||||||
|
44,41,350.8
|
||||||
|
41,43,322.6
|
||||||
|
43,100,338.9
|
||||||
|
100,83,347.9
|
||||||
|
83,87,327.2
|
||||||
|
87,88,321.0
|
||||||
|
88,75,335.8
|
||||||
|
75,51,384.8
|
||||||
|
51,73,391.1
|
||||||
|
73,71,289.3
|
||||||
|
31,155,260.0
|
||||||
|
155,34,320.4
|
||||||
|
34,128,393.3
|
||||||
|
145,115,399.4
|
||||||
|
115,112,328.1
|
||||||
|
112,8,469.4
|
||||||
|
8,117,816.2
|
||||||
|
117,125,397.1
|
||||||
|
125,127,372.7
|
||||||
|
127,109,380.5
|
||||||
|
109,161,355.5
|
||||||
|
161,110,367.7
|
||||||
|
110,160,102.0
|
||||||
|
72,159,342.9
|
||||||
|
159,50,383.3
|
||||||
|
50,74,354.1
|
||||||
|
74,82,350.2
|
||||||
|
82,81,335.4
|
||||||
|
81,99,391.6
|
||||||
|
99,84,354.9
|
||||||
|
84,13,306.4
|
||||||
|
13,40,327.4
|
||||||
|
40,162,413.9
|
||||||
|
162,108,301.9
|
||||||
|
108,146,317.8
|
||||||
|
146,85,376.6
|
||||||
|
85,90,347.0
|
||||||
|
26,27,341.6
|
||||||
|
27,6,359.4
|
||||||
|
6,149,417.8
|
||||||
|
149,126,388.0
|
||||||
|
126,124,384.3
|
||||||
|
124,7,763.3
|
||||||
|
7,114,323.1
|
||||||
|
114,113,351.6
|
||||||
|
113,116,411.9
|
||||||
|
116,144,262.0
|
||||||
|
25,103,350.2
|
||||||
|
103,23,376.3
|
||||||
|
23,165,396.4
|
||||||
|
165,38,381.0
|
||||||
|
38,92,368.0
|
||||||
|
92,37,336.3
|
||||||
|
37,130,357.8
|
||||||
|
130,106,532.3
|
||||||
|
106,131,166.5
|
||||||
|
1,2,371.6
|
||||||
|
2,4,338.1
|
||||||
|
4,76,429.0
|
||||||
|
76,36,366.1
|
||||||
|
36,158,344.5
|
||||||
|
158,151,350.1
|
||||||
|
151,45,358.8
|
||||||
|
45,93,340.9
|
||||||
|
93,80,329.9
|
||||||
|
80,79,384.1
|
||||||
|
79,95,335.7
|
||||||
|
95,98,320.9
|
||||||
|
98,119,340.3
|
||||||
|
119,120,376.8
|
||||||
|
120,122,393.1
|
||||||
|
122,141,428.7
|
||||||
|
141,142,359.3
|
||||||
|
30,165,379.6
|
||||||
|
165,29,41.7
|
||||||
|
29,38,343.3
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
17,153,375.6
|
||||||
|
153,5,256.3
|
||||||
|
153,62,330.9
|
||||||
|
18,6,499.4
|
||||||
|
6,22,254.0
|
||||||
|
22,149,185.4
|
||||||
|
22,4,257.9
|
||||||
|
4,30,236.8
|
||||||
|
30,76,307.0
|
||||||
|
95,98,320.9
|
||||||
|
98,144,45.1
|
||||||
|
45,93,340.9
|
||||||
|
93,106,112.2
|
||||||
|
162,151,113.6
|
||||||
|
151,108,192.9
|
||||||
|
108,45,359.8
|
||||||
|
146,92,311.2
|
||||||
|
92,85,343.9
|
||||||
|
85,37,373.2
|
||||||
|
13,169,326.2
|
||||||
|
169,40,96.1
|
||||||
|
124,13,460.7
|
||||||
|
13,7,305.5
|
||||||
|
7,40,624.1
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
90,132,152.2
|
||||||
|
26,32,106.7
|
||||||
|
9,129,148.3
|
||||||
|
129,153,219.6
|
||||||
|
31,26,116.0
|
||||||
|
26,155,270.7
|
||||||
|
9,128,142.2
|
||||||
|
128,153,215.0
|
||||||
|
153,167,269.7
|
||||||
|
167,62,64.8
|
||||||
|
62,70,332.6
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
44,169,397.8
|
||||||
|
169,41,124.0
|
||||||
|
44,124,375.7
|
||||||
|
124,41,243.9
|
||||||
|
41,7,519.4
|
||||||
|
6,14,289.3
|
||||||
|
14,149,259.0
|
||||||
|
149,39,206.9
|
||||||
|
144,98,45.1
|
||||||
|
19,4,326.8
|
||||||
|
4,14,178.6
|
||||||
|
14,76,299.0
|
||||||
|
15,151,136.4
|
||||||
|
151,44,203.1
|
||||||
|
45,106,260.6
|
||||||
|
106,93,112.2
|
||||||
|
20,165,132.5
|
||||||
|
165,19,289.2
|
||||||
|
89,92,323.2
|
||||||
|
92,147,321.9
|
||||||
|
147,37,48.2
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
104,147,277.5
|
||||||
|
147,148,84.7
|
||||||
|
20,101,201.2
|
||||||
|
101,19,534.4
|
||||||
|
19,137,245.5
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
44,42,342.3
|
||||||
|
42,41,102.5
|
||||||
|
44,8,789.1
|
||||||
|
8,41,657.4
|
||||||
|
41,117,160.5
|
||||||
|
168,167,172.4
|
||||||
|
167,154,165.2
|
||||||
|
143,128,81.9
|
||||||
|
128,10,88.2
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
15,152,135.0
|
||||||
|
152,44,204.6
|
||||||
|
19,77,320.7
|
||||||
|
77,14,299.8
|
||||||
|
14,52,127.6
|
||||||
|
14,127,314.8
|
||||||
|
127,39,280.4
|
||||||
|
39,109,237.0
|
||||||
|
31,160,116.5
|
||||||
|
160,155,272.4
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
32,160,107.7
|
||||||
|
72,162,3274.4
|
||||||
|
162,13,554.5
|
||||||
|
162,40,413.9
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
13,42,319.8
|
||||||
|
42,40,40.7
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
8,13,450.3
|
||||||
|
13,117,378.5
|
||||||
|
117,40,64.0
|
||||||
|
46,162,391.6
|
||||||
|
162,152,115.3
|
||||||
|
152,108,191.4
|
||||||
|
104,108,375.9
|
||||||
|
108,148,311.6
|
||||||
|
148,146,80.0
|
||||||
|
21,90,396.9
|
||||||
|
90,132,152.2
|
||||||
|
101,29,252.3
|
||||||
|
29,137,110.7
|
||||||
|
77,22,353.8
|
||||||
|
22,52,227.8
|
||||||
|
52,30,186.6
|
||||||
|
127,18,425.2
|
||||||
|
18,109,439.1
|
||||||
|
109,22,135.5
|
||||||
|
168,17,232.7
|
||||||
|
17,154,294.2
|
||||||
|
154,5,166.3
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def load_dataset(config):
|
def load_dataset(config):
|
||||||
|
|
||||||
dataset_name = config['basic']['dataset']
|
dataset_name = config['basic']['dataset']
|
||||||
node_num = config['data']['num_nodes']
|
node_num = config['data']['num_nodes']
|
||||||
input_dim = config['data']['input_dim']
|
input_dim = config['data']['input_dim']
|
||||||
|
|
@ -10,4 +11,8 @@ def load_dataset(config):
|
||||||
case 'EcoSolar':
|
case 'EcoSolar':
|
||||||
data_path = os.path.join('./data/EcoSolar.npy')
|
data_path = os.path.join('./data/EcoSolar.npy')
|
||||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||||
|
case 'PEMS08':
|
||||||
|
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||||
|
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def load_graph(config):
|
||||||
|
dataset_path = config['data']['graph_pkl_filename']
|
||||||
|
graph = np.load(dataset_path)
|
||||||
|
# 将inf值填充为0
|
||||||
|
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
return graph
|
||||||
2
main.py
2
main.py
|
|
@ -3,8 +3,6 @@
|
||||||
时空数据深度学习预测项目主程序
|
时空数据深度学习预测项目主程序
|
||||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from utils.args_reader import config_loader
|
from utils.args_reader import config_loader
|
||||||
import utils.init as init
|
import utils.init as init
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,88 @@
|
||||||
|
### STDEN 模块与执行流(缩进层级表)
|
||||||
|
|
||||||
|
模块 | 类/函数 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | STDENModel.forward | inputs: (seq_len, batch_size, num_edges x input_dim) | outputs: (horizon, batch_size, num_edges x output_dim); fe: (nfe:int, time:float)
|
||||||
|
1.1 | Encoder_z0_RNN.forward | (seq_len, batch_size, num_edges x input_dim) | mean: (1, batch_size, num_nodes x latent_dim); std: (1, batch_size, num_nodes x latent_dim)
|
||||||
|
1.1.1 | utils.sample_standard_gaussian | mu: (n_traj, batch, num_nodes x latent_dim); sigma: 同形状 | z0: (n_traj, batch, num_nodes x latent_dim)
|
||||||
|
1.2 | DiffeqSolver.forward | first_point: (n_traj, batch, num_nodes x latent_dim); t: (horizon,) | sol_ys: (horizon, n_traj, batch, num_nodes x latent_dim); fe: (nfe:int, time:float)
|
||||||
|
1.2.1 | ODEFunc.forward | t_local: 标量/1D; y: (B, num_nodes x latent_dim) | dy/dt: (B, num_nodes x latent_dim)
|
||||||
|
1.3 | Decoder.forward | (horizon, n_traj, batch, num_nodes x latent_dim) | (horizon, batch, num_edges x output_dim)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 细节模块 — Encoder_z0_RNN
|
||||||
|
|
||||||
|
步骤 | 操作 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | 重塑到边批 | (seq_len, batch, num_edges x input_dim) | (seq_len, batch, num_edges, input_dim)
|
||||||
|
2 | 合并边到批 | (seq_len, batch, num_edges, input_dim) | (seq_len, batch x num_edges, input_dim)
|
||||||
|
3 | GRU 序列编码 | 同上 | (seq_len, batch x num_edges, rnn_units)
|
||||||
|
4 | 取最后时间步 | 同上 | (batch x num_edges, rnn_units)
|
||||||
|
5 | 还原边维 | (batch x num_edges, rnn_units) | (batch, num_edges, rnn_units)
|
||||||
|
6 | 转置 + 边→节点映射 | (batch, num_edges, rnn_units) 经 inv_grad | (batch, num_nodes, rnn_units)
|
||||||
|
7 | 全连接映射到 2x latent | (batch, num_nodes, rnn_units) | (batch, num_nodes, 2 x latent_dim)
|
||||||
|
8 | 拆分均值/标准差 | 同上 | mean/std: (batch, num_nodes, latent_dim)
|
||||||
|
9 | 展平并加时间维 | (batch, num_nodes, latent_dim) | (1, batch, num_nodes x latent_dim)
|
||||||
|
|
||||||
|
备注:inv_grad 来源于 `utils.graph_grad(adj).T` 并做缩放;`hiddens_to_z0` 为两层 MLP + Tanh 后线性映射至 2 x latent_dim。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 细节模块 — 采样(utils.sample_standard_gaussian)
|
||||||
|
|
||||||
|
步骤 | 操作 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | 重复到 n_traj | mean/std: (1, batch, N·Z) → 重复 | (n_traj, batch, N·Z)
|
||||||
|
2 | 重参数化采样 | mu, sigma | z0: (n_traj, batch, N·Z)
|
||||||
|
|
||||||
|
其中 N·Z = num_nodes x latent_dim。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 细节模块 — DiffeqSolver(含 ODEFunc 调用)
|
||||||
|
|
||||||
|
步骤 | 操作 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | 合并样本维度 | first_point: (n_traj, batch, N·Z) | (n_traj x batch, N·Z)
|
||||||
|
2 | ODE 积分 | t: (horizon,), y0 | pred_y: (horizon, n_traj x batch, N·Z)
|
||||||
|
3 | 还原维度 | 同上 | (horizon, n_traj, batch, N·Z)
|
||||||
|
4 | 统计代价 | odefunc.nfe, elapsed_time | fe: (nfe:int, time:float)
|
||||||
|
|
||||||
|
ODEFunc 默认(filter_type="default")为扩散过程:随机游走支持 + 多阶图卷积门控。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 细节模块 — ODEFunc(默认扩散过程)
|
||||||
|
|
||||||
|
步骤 | 操作 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | 形状整理 | y: (B, N·Z) → (B, N, Z) | (B, N, Z)
|
||||||
|
2 | 多阶图卷积 _gconv | (B, N, Z) | (B, N, Z') 按需设置 Z'(通常保持 Z)
|
||||||
|
3 | 门控 θ | _gconv(..., output=latent_dim) → Sigmoid | θ: (B, N·Z)
|
||||||
|
4 | 生成场 ode_func_net | 堆叠 _gconv + 激活 | f(y): (B, N·Z)
|
||||||
|
5 | 右端梯度 | - θ ⊙ f(y) | dy/dt: (B, N·Z)
|
||||||
|
|
||||||
|
说明:
|
||||||
|
- 支撑矩阵来自 `utils.calculate_random_walk_matrix(adj)`(正向/反向)并构造稀疏 Chebyshev 递推的多阶通道。
|
||||||
|
- 若 `filter_type="unkP"`,则使用 `create_net` 的全连接网络在节点域逐点计算梯度。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 细节模块 — Decoder
|
||||||
|
|
||||||
|
步骤 | 操作 | 输入 (shape) | 输出 (shape)
|
||||||
|
--- | --- | --- | ---
|
||||||
|
1 | 重塑到节点域 | (T, S, B, N·Z) → (T, S, B, N, Z) | (T, S, B, N, Z)
|
||||||
|
2 | 节点→边映射 | 乘以 graph_grad (N, E) | (T, S, B, Z, E)
|
||||||
|
3 | 轨迹与通道均值 | 对 S 和 Z 维做均值 | (T, B, E)
|
||||||
|
4 | 展平到输出维 | 考虑 output_dim(通常为 1) | (T, B, E x output_dim)
|
||||||
|
|
||||||
|
符号:T=horizon,S=n_traj_samples,N=num_nodes,E=num_edges,Z=latent_dim,B=batch。
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 备注与约定
|
||||||
|
- 内部采用边展平后的时序输入:`(seq_len, batch, num_edges x input_dim)`。
|
||||||
|
- 图算子:`utils.graph_grad(adj)` 形状 `(N, E)`;`utils.calculate_random_walk_matrix(adj)` 生成随机游走稀疏矩阵用于图卷积。
|
||||||
|
- 关键超参数(由配置传入):`latent_dim`, `rnn_units`, `gcn_step`, `n_traj_samples`, `ode_method`, `horizon`, `input_dim`, `output_dim`。
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import time
|
||||||
|
|
||||||
|
from torchdiffeq import odeint
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
class DiffeqSolver(nn.Module):
    """Wraps an ODE right-hand side and integrates it over the horizon."""

    def __init__(self, odefunc, method, latent_dim,
                 odeint_rtol = 1e-4, odeint_atol = 1e-5):
        """
        :param odefunc: module computing dy/dt; must expose an `nfe` counter.
        :param method: odeint integration method name (e.g. 'dopri5').
        :param latent_dim: per-node latent dimensionality (kept for reference).
        :param odeint_rtol: relative solver tolerance.
        :param odeint_atol: absolute solver tolerance.
        """
        nn.Module.__init__(self)
        self.odefunc = odefunc
        self.ode_method = method
        self.latent_dim = latent_dim
        self.rtol = odeint_rtol
        self.atol = odeint_atol

    def forward(self, first_point, time_steps_to_pred):
        """
        Decode the latent trajectory through the ODE solver.

        :param first_point: (n_traj_samples, batch_size, num_nodes * latent_dim)
        :param time_steps_to_pred: 1-D tensor of horizon time points
        :return: (pred_y, (nfe, elapsed)) where pred_y has shape
                 (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        """
        n_samples, n_batch = first_point.shape[0], first_point.shape[1]
        # Merge sample and batch dims so odeint integrates a single 2-D state.
        y0 = first_point.reshape(n_samples * n_batch, -1)

        self.odefunc.nfe = 0
        tic = time.time()
        trajectory = odeint(self.odefunc, y0, time_steps_to_pred,
                            rtol=self.rtol, atol=self.atol,
                            method=self.ode_method)
        elapsed = time.time() - tic

        # Split the merged dimension back into (samples, batch).
        trajectory = trajectory.reshape(trajectory.shape[0], n_samples, n_batch, -1)
        return trajectory, (self.odefunc.nfe, elapsed)
|
||||||
|
|
||||||
|
|
@ -0,0 +1,165 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from models.STDEN import utils
|
||||||
|
|
||||||
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||||
|
|
||||||
|
class LayerParams:
    """Lazily creates and caches weight/bias parameters on a host module.

    Parameters are registered on the owning nn.Module so they appear in
    its parameter list and get optimised normally.
    """

    def __init__(self, rnn_network: nn.Module, layer_type: str):
        self._rnn_network = rnn_network
        self._type = layer_type
        self._params_dict = {}   # shape tuple -> nn.Parameter
        self._biases_dict = {}   # length -> nn.Parameter

    def get_weights(self, shape):
        """Return (creating on first use) a Xavier-initialised weight of `shape`."""
        try:
            return self._params_dict[shape]
        except KeyError:
            weight = nn.Parameter(torch.empty(*shape, device=device))
            nn.init.xavier_normal_(weight)
            self._params_dict[shape] = weight
            self._rnn_network.register_parameter(
                '{}_weight_{}'.format(self._type, str(shape)), weight)
            return weight

    def get_biases(self, length, bias_start=0.0):
        """Return (creating on first use) a bias of `length` filled with `bias_start`."""
        if length not in self._biases_dict:
            bias = nn.Parameter(torch.empty(length, device=device))
            nn.init.constant_(bias, bias_start)
            self._biases_dict[length] = bias
            self._rnn_network.register_parameter(
                '{}_biases_{}'.format(self._type, str(length)), bias)
        return self._biases_dict[length]
|
||||||
|
|
||||||
|
class ODEFunc(nn.Module):
    """Right-hand side f(t, y) of the latent ODE.

    The default filter models a gated diffusion process on the graph:
    dy/dt = -theta * g(y), where theta is a sigmoid gate and g is a stack
    of diffusion graph convolutions. filter_type "unkP" instead uses a
    pointwise MLP over each node's latent vector.
    """

    def __init__(self, num_units, latent_dim, adj_mx, gcn_step, num_nodes,
                 gen_layers=1, nonlinearity='tanh', filter_type="default"):
        """
        :param num_units: dimensionality of the hidden layers
        :param latent_dim: dimensionality used for ODE (input and output). Analog of a continuous latent state
        :param adj_mx: adjacency matrix used to build the random-walk supports
        :param gcn_step: number of diffusion (Chebyshev) steps per support
        :param num_nodes: number of graph nodes
        :param gen_layers: hidden layers in each ode func.
        :param nonlinearity: 'tanh' selects tanh; anything else selects relu
        :param filter_type: "default" (diffusion), "unkP" (pointwise MLP), or "IncP"
        """
        super(ODEFunc, self).__init__()
        self._activation = torch.tanh if nonlinearity == 'tanh' else torch.relu

        self._num_nodes = num_nodes
        self._num_units = num_units  # hidden dimension
        self._latent_dim = latent_dim
        self._gen_layers = gen_layers
        self.nfe = 0  # number of function evaluations; reset by the solver

        self._filter_type = filter_type
        if(self._filter_type == "unkP"):
            # Unknown-physics variant: per-node MLP gradient network.
            ode_func_net = utils.create_net(latent_dim, latent_dim, n_units=num_units)
            utils.init_network_weights(ode_func_net)
            self.gradient_net = ode_func_net
        else:
            # Diffusion variant: forward and backward random-walk supports.
            self._gcn_step = gcn_step
            self._gconv_params = LayerParams(self, 'gconv')
            self._supports = []
            supports = []
            supports.append(utils.calculate_random_walk_matrix(adj_mx).T)
            supports.append(utils.calculate_random_walk_matrix(adj_mx.T).T)

            for support in supports:
                self._supports.append(self._build_sparse_matrix(support))

    @staticmethod
    def _build_sparse_matrix(L):
        # Convert a scipy sparse matrix to a torch sparse COO tensor.
        L = L.tocoo()
        indices = np.column_stack((L.row, L.col))
        # this is to ensure row-major ordering to equal torch.sparse.sparse_reorder(L)
        indices = indices[np.lexsort((indices[:, 0], indices[:, 1]))]
        L = torch.sparse_coo_tensor(indices.T, L.data, L.shape, device=device)
        return L

    def forward(self, t_local, y, backwards = False):
        """
        Perform one step in solving ODE. Given current data point y and current time point t_local, returns gradient dy/dt at this time point

        t_local: current time point
        y: value at the current time point, shape (B, num_nodes * latent_dim)

        :return
        - Output: A `2-D` tensor with shape `(B, num_nodes * latent_dim)`.
        """
        self.nfe += 1
        grad = self.get_ode_gradient_nn(t_local, y)
        if backwards:
            # Negate the field when integrating backwards in time.
            grad = -grad
        return grad

    def get_ode_gradient_nn(self, t_local, inputs):
        # Dispatch on the configured filter type; t_local is unused (autonomous system).
        if(self._filter_type == "unkP"):
            grad = self._fc(inputs)
        elif (self._filter_type == "IncP"):
            grad = - self.ode_func_net(inputs)
        else:  # default is diffusion process
            # theta shape: (B, num_nodes * latent_dim); sigmoid gate on the field
            theta = torch.sigmoid(self._gconv(inputs, self._latent_dim, bias_start=1.0))
            grad = - theta * self.ode_func_net(inputs)
        return grad

    def ode_func_net(self, inputs):
        # Stack of graph convolutions generating the (ungated) vector field.
        c = inputs
        for i in range(self._gen_layers):
            c = self._gconv(c, self._num_units)
            c = self._activation(c)
        c = self._gconv(c, self._latent_dim)
        c = self._activation(c)
        return c

    def _fc(self, inputs):
        # Pointwise MLP applied to each node's latent vector ("unkP" variant).
        batch_size = inputs.size()[0]
        grad = self.gradient_net(inputs.view(batch_size * self._num_nodes, self._latent_dim))
        return grad.reshape(batch_size, self._num_nodes * self._latent_dim)  # (batch_size, num_nodes * latent_dim)

    @staticmethod
    def _concat(x, x_):
        # Append x_ as a new leading "order" slice.
        x_ = x_.unsqueeze(0)
        return torch.cat([x, x_], dim=0)

    def _gconv(self, inputs, output_size, bias_start=0.0):
        """Diffusion graph convolution.

        :param inputs: (batch_size, num_nodes * in_dim) flattened node features
        :param output_size: per-node output channels
        :param bias_start: constant used to initialise the bias
        :return: (batch_size, num_nodes * output_size)
        """
        # Reshape input and state to (batch_size, num_nodes, input_dim/state_dim)
        batch_size = inputs.shape[0]
        inputs = torch.reshape(inputs, (batch_size, self._num_nodes, -1))
        input_size = inputs.size(2)

        x = inputs
        x0 = x.permute(1, 2, 0)  # (num_nodes, total_arg_size, batch_size)
        x0 = torch.reshape(x0, shape=[self._num_nodes, input_size * batch_size])
        x = torch.unsqueeze(x0, 0)  # order-0 term: the input itself

        if self._gcn_step == 0:
            pass  # no diffusion: only the identity term is used
        else:
            for support in self._supports:
                x1 = torch.sparse.mm(support, x0)
                x = self._concat(x, x1)

                # Chebyshev-style recursion for higher diffusion orders.
                # NOTE(review): x0 is overwritten inside this loop, so the
                # second support starts its recursion from a mutated x0 —
                # confirm this matches the intended DCRNN-style behaviour.
                for k in range(2, self._gcn_step + 1):
                    x2 = 2 * torch.sparse.mm(support, x1) - x0
                    x = self._concat(x, x2)
                    x1, x0 = x2, x1

        num_matrices = len(self._supports) * self._gcn_step + 1  # Adds for x itself.
        x = torch.reshape(x, shape=[num_matrices, self._num_nodes, input_size, batch_size])
        x = x.permute(3, 1, 2, 0)  # (batch_size, num_nodes, input_size, order)
        x = torch.reshape(x, shape=[batch_size * self._num_nodes, input_size * num_matrices])

        weights = self._gconv_params.get_weights((input_size * num_matrices, output_size))
        x = torch.matmul(x, weights)  # (batch_size * self._num_nodes, output_size)

        biases = self._gconv_params.get_biases(output_size, bias_start)
        x += biases
        # Reshape res back to 2D: (batch_size, num_node, state_dim) -> (batch_size, num_node * state_dim)
        return torch.reshape(x, [batch_size, self._num_nodes * output_size])
|
||||||
|
|
@ -0,0 +1,181 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch.nn.modules.rnn import GRU
|
||||||
|
from models.STDEN.ode_func import ODEFunc
|
||||||
|
from models.STDEN.diffeq_solver import DiffeqSolver
|
||||||
|
from models.STDEN import utils
|
||||||
|
from data.graph_loader import load_graph
|
||||||
|
|
||||||
|
class EncoderAttrs:
    """Shared graph/hyper-parameter attributes for the STDEN model classes."""

    def __init__(self, config, adj_mx):
        cfg = config.get
        # Graph structure.
        self.adj_mx = adj_mx
        self.num_nodes = adj_mx.shape[0]
        self.num_edges = (adj_mx > 0.).sum()
        # Model hyper-parameters (defaults mirror the reference config).
        self.gcn_step = int(cfg('gcn_step', 2))
        self.filter_type = cfg('filter_type', 'default')
        self.num_rnn_layers = int(cfg('num_rnn_layers', 1))
        self.rnn_units = int(cfg('rnn_units'))  # required: no default
        self.latent_dim = int(cfg('latent_dim', 4))
|
||||||
|
|
||||||
|
|
||||||
|
class STDENModel(nn.Module, EncoderAttrs):
    """STDEN main model: spatio-temporal differential equation network.

    Pipeline: RNN recognition network -> sample initial latent state z0 ->
    neural-ODE integration over the horizon -> decode latents back to
    edge-space outputs.
    """

    def __init__(self, config):
        nn.Module.__init__(self)
        adj_mx = load_graph(config)
        EncoderAttrs.__init__(self, config['model'], adj_mx)

        # Recognition (encoder) network.
        self.encoder_z0 = Encoder_z0_RNN(config['model'], adj_mx)

        model_kwargs = config['model']
        # ODE solver configuration.
        self.n_traj_samples = int(model_kwargs.get('n_traj_samples', 1))
        self.ode_method = model_kwargs.get('ode_method', 'dopri5')
        self.atol = float(model_kwargs.get('odeint_atol', 1e-4))
        self.rtol = float(model_kwargs.get('odeint_rtol', 1e-3))
        self.num_gen_layer = int(model_kwargs.get('gen_layers', 1))
        self.ode_gen_dim = int(model_kwargs.get('gen_dim', 64))

        # ODE dynamics function and solver.
        odefunc = ODEFunc(
            self.ode_gen_dim, self.latent_dim, adj_mx,
            self.gcn_step, self.num_nodes, filter_type=self.filter_type
        )
        self.diffeq_solver = DiffeqSolver(
            odefunc, self.ode_method, self.latent_dim,
            odeint_rtol=self.rtol, odeint_atol=self.atol
        )

        # Optionally keep the latent trajectory for later inspection.
        self.save_latent = bool(model_kwargs.get('save_latent', False))
        self.latent_feat = None

        # Decoder.
        self.horizon = int(model_kwargs.get('horizon', 1))
        self.out_feat = int(model_kwargs.get('output_dim', 1))
        self.decoder = Decoder(
            self.out_feat, adj_mx, self.num_nodes, self.num_edges
        )

    def forward(self, inputs, labels=None, batches_seen=None):
        """
        seq2seq forward pass.

        :param inputs: (batch_size, seq_len, num_edges, input_dim)
        :param labels: unused here; kept for a trainer-compatible signature
        :param batches_seen: unused here; number of batches seen so far
        :return: (outputs, fe) where outputs is shaped like `inputs` and
                 fe = (number of ODE function evaluations, solver wall time)
        """
        B, T, N, C = inputs.shape
        # Bug fix: `view(T, B, N * C)` on a (B, T, N, C) tensor merely
        # reinterprets memory and scrambles the batch and time axes;
        # permute to time-major first, then flatten.
        inputs = inputs.permute(1, 0, 2, 3).reshape(T, B, N * C)
        first_point_mu, first_point_std = self.encoder_z0(inputs)

        # Sample n_traj_samples initial states via reparameterisation.
        means_z0 = first_point_mu.repeat(self.n_traj_samples, 1, 1)
        sigma_z0 = first_point_std.repeat(self.n_traj_samples, 1, 1)
        first_point_enc = utils.sample_standard_gaussian(means_z0, sigma_z0)

        # Normalised prediction time grid in [0, 1).
        time_steps_to_predict = torch.arange(start=0, end=self.horizon, step=1).float()
        time_steps_to_predict = time_steps_to_predict / len(time_steps_to_predict)

        # Integrate the latent ODE.
        sol_ys, fe = self.diffeq_solver(first_point_enc, time_steps_to_predict)

        if self.save_latent:
            self.latent_feat = torch.mean(sol_ys.detach(), axis=1)
        # Decode latents to edge space: (horizon, B, num_edges * output_dim).
        outputs = self.decoder(sol_ys)

        # Bug fix: same axis-scrambling issue on the way out; swap
        # (time, batch) -> (batch, time) before reshaping.
        # NOTE(review): this assumes horizon == T and that the decoder's flat
        # edge/output dim equals N * C — confirm for this configuration.
        outputs = outputs.transpose(0, 1).reshape(B, T, N, C)

        return outputs, fe
|
||||||
|
|
||||||
|
|
||||||
|
class Encoder_z0_RNN(nn.Module, EncoderAttrs):
    """RNN recognition network: encodes the input sequence into the mean and
    std of the initial latent state z0 (one latent vector per node)."""

    def __init__(self, config, adj_mx):
        nn.Module.__init__(self)
        EncoderAttrs.__init__(self, config, adj_mx)

        self.recg_type = config.get('recg_type', 'gru')
        self.input_dim = int(config.get('input_dim', 1))

        if self.recg_type == 'gru':
            self.gru_rnn = GRU(self.input_dim, self.rnn_units)
        else:
            raise NotImplementedError("只支持'gru'识别网络")

        # Edge-to-node mapping: transposed graph gradient, (num_edges, num_nodes),
        # with every non-zero entry set to 0.5 (averages an edge's two endpoints).
        self.inv_grad = utils.graph_grad(adj_mx).transpose(-2, -1)
        self.inv_grad[self.inv_grad != 0.] = 0.5

        # Maps the final hidden state to mean/std of z0 (2 * latent_dim).
        self.hiddens_to_z0 = nn.Sequential(
            nn.Linear(self.rnn_units, 50),
            nn.Tanh(),
            nn.Linear(50, self.latent_dim * 2)
        )
        utils.init_network_weights(self.hiddens_to_z0)

    def forward(self, inputs):
        """
        Encoder forward pass.

        :param inputs: (seq_len, batch_size, num_edges * input_dim)
        :return: mean, std, each (1, batch_size, num_nodes * latent_dim)
        """
        seq_len, batch_size = inputs.size(0), inputs.size(1)
        n_edges = int(self.num_edges)

        # Bug fix: the flat input dimension is num_edges * input_dim (see the
        # docstring and the (num_edges, num_nodes) shape of inv_grad), so fold
        # the EDGE axis into the batch — the original used num_nodes, which
        # breaks whenever the edge count differs from the node count.
        inputs = inputs.reshape(seq_len, batch_size, n_edges, self.input_dim)
        inputs = inputs.reshape(seq_len, batch_size * n_edges, self.input_dim)

        # Run the GRU over time and keep the final step only.
        outputs, _ = self.gru_rnn(inputs)
        last_output = outputs[-1]  # (batch * num_edges, rnn_units)

        # Restore the edge axis, then map edges -> nodes with inv_grad.
        last_output = torch.reshape(last_output, (batch_size, n_edges, -1))
        # Bug fix: torch.transpose takes two int dims, not a tuple —
        # the original `torch.transpose(last_output, (-2, -1))` raises.
        last_output = torch.transpose(last_output, -2, -1)
        last_output = torch.matmul(last_output, self.inv_grad).transpose(-2, -1)

        # Split the 2*latent projection into mean and (non-negative) std.
        mean, std = utils.split_last_dim(self.hiddens_to_z0(last_output))
        mean = mean.reshape(batch_size, -1)
        std = std.reshape(batch_size, -1).abs()

        return mean.unsqueeze(0), std.unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
class Decoder(nn.Module):
    """Decoder: maps node-space latent states to edge-space outputs via the
    graph gradient operator, averaging over trajectory samples and latent
    channels."""

    def __init__(self, output_dim, adj_mx, num_nodes, num_edges):
        super(Decoder, self).__init__()
        self.num_nodes = num_nodes
        self.num_edges = num_edges
        # (num_nodes, num_edges) incidence-style operator.
        self.grap_grad = utils.graph_grad(adj_mx)
        self.output_dim = output_dim

    def forward(self, inputs):
        """
        :param inputs: (horizon, n_traj_samples, batch_size, num_nodes * latent_dim)
        :return: outputs: (horizon, batch_size, num_edges * output_dim)
        """
        horizon, n_traj_samples, batch_size = inputs.size()[:3]

        # (T, S, B, N, Z) -> (T, S, B, Z, N)
        inputs = inputs.reshape(horizon, n_traj_samples, batch_size, self.num_nodes, -1).transpose(-2, -1)
        latent_dim = inputs.size(-2)

        # Node-to-edge transform: (..., Z, N) @ (N, E) -> (..., Z, E).
        outputs = torch.matmul(inputs, self.grap_grad)

        # Bug fix: the dimension produced by the matmul is num_edges, not
        # num_nodes — reshaping with num_nodes fails whenever E != N.
        outputs = outputs.reshape(horizon, n_traj_samples, batch_size, latent_dim,
                                  int(self.num_edges), self.output_dim)
        # Average over latent channels (axis 3) and trajectory samples (axis 1).
        outputs = torch.mean(torch.mean(outputs, axis=3), axis=1)
        outputs = outputs.reshape(horizon, batch_size, -1)

        return outputs
|
||||||
|
|
||||||
|
|
@ -0,0 +1,234 @@
|
||||||
|
import logging
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import scipy.sparse as sp
|
||||||
|
import sys
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
|
||||||
|
class DataLoader(object):
    """Mini-batch iterator over paired (x, y) numpy arrays.

    Optionally pads with copies of the final sample so the sample count is
    divisible by the batch size, and can shuffle once at construction time.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=False):
        """
        :param xs: input samples, indexed along axis 0.
        :param ys: target samples, indexed along axis 0.
        :param batch_size: number of samples per yielded batch.
        :param pad_with_last_sample: pad with the last sample to make number
            of samples divisible to batch_size.
        :param shuffle: shuffle xs/ys jointly once, at construction.
        """
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            # Repeat the final sample until len(xs) % batch_size == 0.
            n_pad = (batch_size - (len(xs) % batch_size)) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], n_pad, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], n_pad, axis=0)], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        if shuffle:
            order = np.random.permutation(self.size)
            xs, ys = xs[order], ys[order]
        self.xs = xs
        self.ys = ys

    def get_iterator(self):
        """Return a generator of (x, y) batches, resetting the cursor first."""
        self.current_ind = 0

        def _wrapper():
            while self.current_ind < self.num_batch:
                lo = self.batch_size * self.current_ind
                hi = min(self.size, lo + self.batch_size)
                yield (self.xs[lo:hi, ...], self.ys[lo:hi, ...])
                self.current_ind += 1

        return _wrapper()
|
||||||
|
|
||||||
|
|
||||||
|
class StandardScaler:
    """Z-score normaliser: transform maps x -> (x - mean) / std."""

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, data):
        """Standardise `data` with the stored statistics."""
        return (data - self.mean) / self.std

    def inverse_transform(self, data):
        """Undo `transform`, mapping z -> z * std + mean."""
        return (data * self.std) + self.mean
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_random_walk_matrix(adj_mx):
    """Return D^-1 * A as a sparse COO matrix (random-walk transition matrix).

    Rows with zero out-degree stay all-zero instead of dividing by zero.

    :param adj_mx: dense or sparse adjacency matrix.
    :return: scipy.sparse.coo_matrix of the same shape.
    """
    adj = sp.coo_matrix(adj_mx)
    degree = np.asarray(adj.sum(1)).flatten()
    with np.errstate(divide='ignore'):
        inv_degree = np.power(degree, -1)
    # Zero-degree rows produced inf; zero them out.
    inv_degree[np.isinf(inv_degree)] = 0.
    return sp.diags(inv_degree).dot(adj).tocoo()
|
||||||
|
|
||||||
|
|
||||||
|
def config_logging(log_dir, log_filename='info.log', level=logging.INFO):
    """Configure root logging with a file handler (log_dir/log_filename)
    and a stdout handler.

    :param log_dir: directory for the log file; created if missing.
    :param log_filename: log file name inside `log_dir`.
    :param level: level applied to both handlers and the root config.
    """
    # Add file handler and stdout handler
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    # Create the log directory if necessary. Replaces the old
    # `try: os.makedirs(...) except OSError: pass`, which also swallowed
    # real failures such as permission errors.
    os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
    file_handler.setFormatter(formatter)
    file_handler.setLevel(level=level)
    # Add console handler.
    console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(level=level)
    logging.basicConfig(handlers=[file_handler, console_handler], level=level)
|
||||||
|
|
||||||
|
|
||||||
|
def get_logger(log_dir, name, log_filename='info.log', level=logging.INFO):
    """Create (or fetch) a named logger writing to log_dir/log_filename and
    to stdout.

    Bug fix: handlers are attached only on first creation, so repeated calls
    with the same name no longer duplicate every log line.

    :param log_dir: directory containing the log file (must exist).
    :param name: logger name (usually the module/experiment name).
    :param log_filename: log file name inside `log_dir`.
    :param level: logging level for the logger.
    :return: the configured logging.Logger.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)
    if not logger.handlers:
        # Add file handler and stdout handler (first call only).
        formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
        file_handler = logging.FileHandler(os.path.join(log_dir, log_filename))
        file_handler.setFormatter(formatter)
        # Add console handler.
        console_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(console_formatter)
        logger.addHandler(file_handler)
        logger.addHandler(console_handler)
    logger.info('Log directory: %s', log_dir)
    return logger
|
||||||
|
|
||||||
|
|
||||||
|
def get_log_dir(kwargs):
    """Return the configured log directory, building (and creating) a
    run-specific one from the hyper-parameters when none is set.

    :param kwargs: full config dict with 'train', 'data', 'model' sections.
    :return: path of the log directory.
    """
    log_dir = kwargs['train'].get('log_dir')
    if log_dir is None:
        model_cfg = kwargs['model']
        batch_size = kwargs['data'].get('batch_size')

        # Abbreviate the filter type for the run id.
        abbr_map = {'unkP': 'UP', 'IncP': 'NV'}
        filter_type_abbr = abbr_map.get(model_cfg.get('filter_type'), 'DF')

        run_id = 'STDEN_%s-%d_%s-%d_L-%d_N-%d_M-%s_bs-%d_%d-%d_%s/' % (
            model_cfg.get('recg_type'), model_cfg.get('rnn_units'),
            filter_type_abbr, model_cfg.get('gcn_step'),
            model_cfg.get('latent_dim'), model_cfg.get('n_traj_samples'),
            model_cfg.get('ode_method'), batch_size,
            model_cfg.get('seq_len'), model_cfg.get('horizon'),
            time.strftime('%m%d%H%M%S'))
        log_dir = os.path.join(kwargs.get('log_base_dir'), run_id)
        if not os.path.exists(log_dir):
            os.makedirs(log_dir)
    return log_dir
|
||||||
|
|
||||||
|
|
||||||
|
def load_dataset(dataset_dir, batch_size, val_batch_size=None, **kwargs):
    """Load the train/val/test splits from `dataset_dir`, z-normalise them
    with the training-input statistics, and wrap each split in a DataLoader.

    :param dataset_dir: directory holding flow.npz (BJ) or train/val/test.npz.
    :param batch_size: batch size for the training loader.
    :param val_batch_size: batch size for the val/test loaders.
    :return: dict with the arrays, three loaders, and the fitted scaler.
    """
    categories = ['train', 'val', 'test']
    if ('BJ' in dataset_dir):
        # Single flow.npz; convert the read-only NpzFile to a writable dict.
        data = dict(np.load(os.path.join(dataset_dir, 'flow.npz')))
        for cat in categories:
            data['x_' + cat] = data['x_' + cat]  # [..., :4] would drop the time index
    else:
        # One .npz per split, each with 'x' and 'y' arrays.
        data = {}
        for cat in categories:
            split = np.load(os.path.join(dataset_dir, cat + '.npz'))
            data['x_' + cat] = split['x']
            data['y_' + cat] = split['y']
    # Fit the scaler on the training inputs only, then apply everywhere.
    scaler = StandardScaler(mean=data['x_train'].mean(), std=data['x_train'].std())
    for cat in categories:
        data['x_' + cat] = scaler.transform(data['x_' + cat])
        data['y_' + cat] = scaler.transform(data['y_' + cat])
    data['train_loader'] = DataLoader(data['x_train'], data['y_train'], batch_size, shuffle=True)
    data['val_loader'] = DataLoader(data['x_val'], data['y_val'], val_batch_size, shuffle=False)
    data['test_loader'] = DataLoader(data['x_test'], data['y_test'], val_batch_size, shuffle=False)
    data['scaler'] = scaler

    return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_graph_data(pkl_filename):
    """Load an adjacency matrix stored as a numpy file.

    NOTE(review): despite the parameter name, this reads a .npy file,
    not a pickle — confirm callers pass numpy paths.
    """
    return np.load(pkl_filename)
|
||||||
|
|
||||||
|
|
||||||
|
def graph_grad(adj_mx):
    """Fetch the graph gradient (incidence-style) operator.

    For every directed edge (i, j) with positive weight, the corresponding
    column carries +1 at source node i and -1 at target node j.

    :param adj_mx: (num_nodes, num_nodes) adjacency matrix.
    :return: (num_nodes, num_edges) dense float tensor.
    """
    num_nodes = adj_mx.shape[0]
    num_edges = (adj_mx > 0.).sum()
    grad = torch.zeros(num_nodes, num_edges)
    e = 0
    for i in range(num_nodes):
        for j in range(num_nodes):
            # Bug fix: use the same `> 0` test as the edge count above. The
            # original skipped only exact zeros, so any negative weight made
            # `e` overflow num_edges and raised an IndexError.
            if not (adj_mx[i, j] > 0.):
                continue
            grad[i, e] = 1.
            grad[j, e] = -1.
            e += 1
    return grad
|
||||||
|
|
||||||
|
|
||||||
|
def init_network_weights(net, std=0.1):
    """Initialise every nn.Linear in `net`: N(0, std) weights, zero biases.

    Just for nn.Linear layers; all other module types are left untouched.
    """
    for layer in net.modules():
        if isinstance(layer, nn.Linear):
            nn.init.normal_(layer.weight, mean=0, std=std)
            nn.init.constant_(layer.bias, val=0)
|
||||||
|
|
||||||
|
|
||||||
|
def split_last_dim(data):
    """Split the last dimension in half; return (first_half, second_half)."""
    half = data.size()[-1] // 2
    return data[..., :half], data[..., half:]
|
||||||
|
|
||||||
|
|
||||||
|
def get_device(tensor):
    """Return the device of `tensor`: its CUDA device index when on GPU,
    otherwise the CPU device object."""
    if tensor.is_cuda:
        return tensor.get_device()
    return torch.device("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
def sample_standard_gaussian(mu, sigma):
    """Reparameterised Gaussian sample: mu + sigma * eps, eps ~ N(0, 1)."""
    # Inlined device lookup: CPU unless the tensor lives on CUDA.
    dev = mu.get_device() if mu.is_cuda else torch.device("cpu")
    normal = torch.distributions.normal.Normal(
        torch.Tensor([0.]).to(dev), torch.Tensor([1.]).to(dev))
    eps = normal.sample(mu.size()).squeeze(-1)
    return eps * sigma.float() + mu.float()
|
||||||
|
|
||||||
|
|
||||||
|
def create_net(n_inputs, n_outputs, n_layers=0, n_units=100, nonlinear=nn.Tanh):
    """Build an MLP: input layer, `n_layers` hidden blocks, output layer.

    :param n_inputs: input feature size.
    :param n_outputs: output feature size.
    :param n_layers: number of extra hidden (nonlinearity + Linear) blocks.
    :param n_units: hidden layer width.
    :param nonlinear: activation module class inserted between layers.
    :return: nn.Sequential implementing the MLP.
    """
    net = [nn.Linear(n_inputs, n_units)]
    for _ in range(n_layers):
        net += [nonlinear(), nn.Linear(n_units, n_units)]
    net += [nonlinear(), nn.Linear(n_units, n_outputs)]
    return nn.Sequential(*net)
|
||||||
|
|
@ -0,0 +1,180 @@
|
||||||
|
import torch
|
||||||
|
import math
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from models.STGODE.odegcn import ODEG
|
||||||
|
from models.STGODE.adj import get_A_hat
|
||||||
|
|
||||||
|
class Chomp1d(nn.Module):
    """
    extra dimension will be added by padding, remove it
    """

    def __init__(self, chomp_size):
        """
        :param chomp_size: number of trailing time steps to trim.
        """
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        :param x: (B, F, N, T + chomp_size)
        :return: (B, F, N, T), contiguous.
        """
        # Bug fix: when chomp_size == 0 the original slice `[..., :-0]`
        # returned an empty tensor; slice by an explicit end index instead.
        return x[:, :, :, :x.size(3) - self.chomp_size].contiguous()
|
||||||
|
|
||||||
|
|
||||||
|
class TemporalConvNet(nn.Module):
    """
    time dilation convolution

    Stack of dilated causal 1D convolutions over the time axis with a
    residual (ResNet-style) connection in forward().
    """

    def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
        """
        Args:
            num_inputs : channel's number of input data's feature
            num_channels : numbers of data feature tranform channels, the last is the output channel
            kernel_size : using 1d convolution, so the real kernel is (1, kernel_size)
            dropout : dropout probability applied after each level
        """
        super(TemporalConvNet, self).__init__()
        layers = []
        num_levels = len(num_channels)
        for i in range(num_levels):
            dilation_size = 2 ** i  # exponentially growing receptive field
            in_channels = num_inputs if i == 0 else num_channels[i - 1]
            out_channels = num_channels[i]
            padding = (kernel_size - 1) * dilation_size
            self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
                                  padding=(0, padding))
            self.conv.weight.data.normal_(0, 0.01)
            self.chomp = Chomp1d(padding)  # remove look-ahead introduced by padding
            self.relu = nn.ReLU()
            self.dropout = nn.Dropout(dropout)

            layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]

        self.network = nn.Sequential(*layers)
        # 1x1 conv to match channel counts for the residual path when needed
        self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
        if self.downsample:
            self.downsample.weight.data.normal_(0, 0.01)

    def forward(self, x):
        """
        like ResNet
        Args:
            X : input data of shape (B, N, T, F)
        Returns:
            tensor of shape (B, N, T, num_channels[-1])
        """
        # permute shape to (B, F, N, T)
        y = x.permute(0, 3, 1, 2)
        # BUG FIX: the original `F.relu(self.network(y) + self.downsample(y)
        # if self.downsample else y)` parsed as
        # `(network + downsample) if downsample else y`, so when no downsample
        # was needed the whole convolution stack was skipped. Parenthesize the
        # residual term instead.
        residual = self.downsample(y) if self.downsample else y
        y = F.relu(self.network(y) + residual)
        y = y.permute(0, 2, 3, 1)
        return y
|
class GCN(nn.Module):
    """One-hop graph convolution: relu(A_hat @ X @ theta)."""

    def __init__(self, A_hat, in_channels, out_channels, ):
        super(GCN, self).__init__()
        self.A_hat = A_hat
        self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
        self.reset()

    def reset(self):
        # uniform init scaled by the output fan, as in the original paper code
        bound = 1. / math.sqrt(self.theta.shape[1])
        self.theta.data.uniform_(-bound, bound)

    def forward(self, X):
        # aggregate neighbour features over the node axis (j)
        aggregated = torch.einsum('ij, kjlm-> kilm', self.A_hat, X)
        # project the feature axis (m) through theta, then ReLU
        return F.relu(torch.einsum('kjlm, mn->kjln', aggregated, self.theta))
|
class STGCNBlock(nn.Module):
    """Temporal-conv / graph-ODE / temporal-conv sandwich with batch-norm over nodes."""

    def __init__(self, in_channels, out_channels, num_nodes, A_hat):
        """
        Args:
            in_channels: Number of input features at each node in each time step.
            out_channels: a list of feature channels in timeblock, the last is output feature channel
            num_nodes: Number of nodes in the graph
            A_hat: the normalized adjacency matrix
        """
        super(STGCNBlock, self).__init__()
        self.A_hat = A_hat
        self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
        # NOTE(review): temporal_dim is hard-coded to 12 — assumes a 12-step
        # input window; confirm against the data config.
        self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
        self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
        self.batch_norm = nn.BatchNorm2d(num_nodes)

    def forward(self, X):
        """
        Args:
            X: Input data of shape (batch_size, num_nodes, num_timesteps, num_features)
        Return:
            Output data of shape(batch_size, num_nodes, num_timesteps, out_channels[-1])
        """
        hidden = self.temporal1(X)
        hidden = self.odeg(hidden)
        hidden = self.temporal2(F.relu(hidden))
        return self.batch_norm(hidden)
|
class ODEGCN(nn.Module):
    """ the overall network framework

    Six parallel branches (3 spatial + 3 semantic), each a pair of stacked
    STGCN blocks; branch outputs are combined with an element-wise max and
    mapped to the forecast horizon by a small MLP.
    """

    def __init__(self, config):
        """
        Args:
            config : nested dict. Reads num_nodes from config['data'] and
                num_features / history / horizon from config['model'];
                adjacency matrices come from get_A_hat(config).
        """
        super(ODEGCN, self).__init__()
        args = config['model']
        num_nodes = config['data']['num_nodes']
        num_features = args['num_features']
        num_timesteps_input = args['history']
        num_timesteps_output = args['horizon']
        A_sp_hat, A_se_hat = get_A_hat(config)

        def build_branch(adj):
            # two stacked STGCN blocks sharing one adjacency matrix
            return nn.Sequential(
                STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=adj),
                STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
                           num_nodes=num_nodes, A_hat=adj))

        # spatial graph branches
        self.sp_blocks = nn.ModuleList([build_branch(A_sp_hat) for _ in range(3)])
        # semantic graph branches
        self.se_blocks = nn.ModuleList([build_branch(A_se_hat) for _ in range(3)])

        self.pred = nn.Sequential(
            nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
            nn.ReLU(),
            nn.Linear(num_timesteps_output * 32, num_timesteps_output)
        )

    def forward(self, x):
        """
        Args:
            x : input tensor; only the first feature channel is used, and the
                first permute maps it to (B, N, T, 1) for the ST blocks.
                (The original docstring said (B, N, T, F); given the
                permute(0, 2, 1, 3) the actual input looks like (B, T, N, F)
                — confirm against the data loader.)
        Returns:
            prediction of shape (batch_size, horizon, num_nodes, 1)
        """
        x = x[..., 0:1].permute(0, 2, 1, 3)

        branch_outputs = [branch(x) for branch in self.sp_blocks]      # spatial graph
        branch_outputs += [branch(x) for branch in self.se_blocks]     # semantic graph

        stacked = torch.stack(branch_outputs)
        # element-wise max over the six branches, then flatten time x channels
        x = torch.max(stacked, dim=0)[0]
        x = x.reshape((x.shape[0], x.shape[1], -1))

        return self.pred(x).permute(0, 2, 1).unsqueeze(dim=-1)
@ -0,0 +1,132 @@
|
||||||
|
import csv
|
||||||
|
import os
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from fastdtw import fastdtw
|
||||||
|
from tqdm import tqdm
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# Map from a dataset's node count to its [raw data .npz, adjacency .csv] paths.
# NOTE(review): not referenced by the code visible in this file — possibly a
# leftover from the upstream STGODE repository; confirm before removing.
files = {
    358: ['PEMS03/PEMS03.npz', 'PEMS03/PEMS03.csv'],
    307: ['PEMS04/PEMS04.npz', 'PEMS04/PEMS04.csv'],
    883: ['PEMS07/PEMS07.npz', 'PEMS07/PEMS07.csv'],
    170: ['PEMS08/PEMS08.npz', 'PEMS08/PEMS08.csv'],
    # 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
    # 'pemsD7M': ['PEMSD7M/PEMSD7M.npz', 'PEMSD7M/distance.csv'],
    # 'pemsD7L': ['PEMSD7L/PEMSD7L.npz', 'PEMSD7L/distance.csv']
}
|
||||||
|
|
||||||
|
def get_A_hat(config):
    """read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw

    Reads sigma1/sigma2/thres1/thres2 from config['model'], the dataset name
    from config['basic'] and the file locations from config['data']. Both
    distance matrices are cached as .npy files under data/PEMS0<k>/.

    Returns:
        (semantic A_hat, spatial A_hat): two degree-normalized adjacency
        tensors moved to config['basic']['device'].
        NOTE(review): the caller unpacks this as `A_sp_hat, A_se_hat = ...`,
        i.e. the DTW (semantic) matrix lands in A_sp_hat — confirm the
        intended ordering.
    """
    file_path = config['data']['graph_pkl_filename']
    filename = config['basic']['dataset']
    dataset_path = config['data']['dataset_dir']
    args = config['model']

    data = np.load(file_path)
    data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
    num_node = data.shape[1]
    # z-score normalisation per feature channel
    mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
    std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
    data = (data - mean_value) / std_value

    # Compute DTW distances; reuse the cached file when it exists.
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
        # average the first feature over whole days (24h * 12 samples/h)
        data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
                            axis=0)
        data_mean = data_mean.squeeze().T
        dtw_distance = np.zeros((num_node, num_node))
        for i in tqdm(range(num_node)):
            for j in range(i, num_node):
                dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
        # mirror the upper triangle (DTW distance is symmetric)
        for i in range(num_node):
            for j in range(i):
                dtw_distance[i][j] = dtw_distance[j][i]
        np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)

    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')

    # Gaussian-kernel + threshold the DTW distances into a binary semantic graph.
    mean = np.mean(dist_matrix)
    std = np.std(dist_matrix)
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma1']
    dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
    dtw_matrix = np.zeros_like(dist_matrix)
    dtw_matrix[dist_matrix > args['thres1']] = 1

    # Compute spatial distances; reuse the cached file when it exists.
    if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
        if num_node == 358:
            # PEMS03 sensor ids are arbitrary: build id -> row-index mapping
            with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
            # read the edge list CSV, skipping the header row
            # NOTE(review): '(unknown)' looks like a redacted/placeholder
            # filename from the diff — restore the real CSV name.
            df = pd.read_csv(f'{dataset_path}/(unknown).csv', skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(id_dict[row[0]])
                end = int(id_dict[row[1]])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
        else:
            # read the edge list CSV, skipping the header row
            df = pd.read_csv(f'{dataset_path}/(unknown).csv', skiprows=1, header=None)
            dist_matrix = np.zeros((num_node, num_node)) + float('inf')
            for _, row in df.iterrows():
                start = int(row[0])
                end = int(row[1])
                dist_matrix[start][end] = float(row[2])
                dist_matrix[end][start] = float(row[2])
            np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)

    # BUG FIX: when the spatial cache already existed, dist_matrix still held
    # the DTW gaussian matrix from above and was normalized as if it were the
    # spatial one. Always load the (possibly just-saved) spatial distances.
    dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy')

    # normalization over the finite entries only (inf marks "no edge")
    std = np.std(dist_matrix[dist_matrix != float('inf')])
    mean = np.mean(dist_matrix[dist_matrix != float('inf')])
    dist_matrix = (dist_matrix - mean) / std
    sigma = args['sigma2']
    sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
    sp_matrix[sp_matrix < args['thres2']] = 0

    return (get_normalized_adj(dtw_matrix).to(config['basic']['device']),
            get_normalized_adj(sp_matrix).to(config['basic']['device']))
|
def get_normalized_adj(A):
    """
    Returns a tensor, the degree normalized adjacency matrix.

    Computes alpha/2 * (I + D^{-1/2} A D^{-1/2}) with alpha = 0.8.
    """
    alpha = 0.8
    degrees = np.array(np.sum(A, axis=1)).reshape((-1,))
    degrees[degrees <= 10e-5] = 10e-5  # Prevent infs
    inv_sqrt_deg = np.reciprocal(np.sqrt(degrees))
    # D^{-1/2} A D^{-1/2} via row and column scaling
    A_wave = inv_sqrt_deg.reshape((-1, 1)) * A * inv_sqrt_deg.reshape((1, -1))
    A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
    return torch.from_numpy(A_reg.astype(np.float32))
||||||
|
# Stand-alone driver: precompute the adjacency caches for several datasets.
if __name__ == '__main__':
    config = {
        'sigma1': 0.1,
        'sigma2': 10,
        'thres1': 0.6,
        'thres2': 0.5,
        'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
    }

    # NOTE(review): get_A_hat reads nested keys (config['data'],
    # config['basic'], config['model']) that this flat dict does not provide,
    # so this driver looks stale relative to the function above — confirm
    # before relying on it.
    for nodes in [358, 170, 883]:
        args = {'num_nodes': nodes, **config}
        get_A_hat(args)
@ -0,0 +1,74 @@
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from torch import nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
# Whether use adjoint method or not.
|
||||||
|
adjoint = False
|
||||||
|
if adjoint:
|
||||||
|
from torchdiffeq import odeint_adjoint as odeint
|
||||||
|
else:
|
||||||
|
from torchdiffeq import odeint
|
||||||
|
|
||||||
|
|
||||||
|
# Define the ODE function.
|
||||||
|
# Input:
|
||||||
|
# --- t: A tensor with shape [], meaning the current time.
|
||||||
|
# --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
|
||||||
|
# Output:
|
||||||
|
# --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
|
||||||
|
# Define the ODE function.
# Input:
# --- t: A tensor with shape [], meaning the current time.
# --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
# Output:
# --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
class ODEFunc(nn.Module):
    """dx/dt for the graph ODE: graph, feature and temporal mixing of x plus x0."""

    def __init__(self, feature_dim, temporal_dim, adj):
        super(ODEFunc, self).__init__()
        self.adj = adj
        self.x0 = None  # initial state, injected via ODEblock.set_x0
        self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1]))
        self.beta = 0.6
        self.w = nn.Parameter(torch.eye(feature_dim))
        self.d = nn.Parameter(torch.zeros(feature_dim) + 1)
        self.w2 = nn.Parameter(torch.eye(temporal_dim))
        self.d2 = nn.Parameter(torch.zeros(temporal_dim) + 1)

    def forward(self, t, x):
        # per-node gate, broadcast to (1, N, 1, 1)
        gate = torch.sigmoid(self.alpha).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
        graph_mix = torch.einsum('ij, kjlm->kilm', self.adj, x)

        # ensure the eigenvalues to be less than 1
        d_feat = torch.clamp(self.d, min=0, max=1)
        w_feat = torch.mm(self.w * d_feat, torch.t(self.w))
        feat_mix = torch.einsum('ijkl, lm->ijkm', x, w_feat)

        d_time = torch.clamp(self.d2, min=0, max=1)
        w_time = torch.mm(self.w2 * d_time, torch.t(self.w2))
        time_mix = torch.einsum('ijkl, km->ijml', x, w_time)

        return gate / 2 * graph_mix - x + feat_mix - x + time_mix - x + self.x0
|
class ODEblock(nn.Module):
    """Integrate *odefunc* over the interval self.t and return the endpoint state."""

    def __init__(self, odefunc, t=torch.tensor([0,1])):
        super(ODEblock, self).__init__()
        self.t = t
        self.odefunc = odefunc

    def set_x0(self, x0):
        # keep a detached copy of the initial state for the ODE function
        self.odefunc.x0 = x0.clone().detach()

    def forward(self, x):
        time_span = self.t.type_as(x)
        # index 1 = the solution at the final time point
        return odeint(self.odefunc, x, time_span, method='euler')[1]
|
# Define the ODEGCN model.
|
||||||
|
# Define the ODEGCN model.
class ODEG(nn.Module):
    """Graph-ODE wrapper: integrate ODEFunc over [0, time], then apply ReLU."""

    def __init__(self, feature_dim, temporal_dim, adj, time):
        super(ODEG, self).__init__()
        self.odeblock = ODEblock(ODEFunc(feature_dim, temporal_dim, adj),
                                 t=torch.tensor([0, time]))

    def forward(self, x):
        self.odeblock.set_x0(x)
        return F.relu(self.odeblock(x))
@ -1,6 +1,12 @@
|
||||||
|
from models.STDEN.stden_model import STDENModel
|
||||||
|
from models.STGODE.STGODE import ODEGCN
|
||||||
|
|
||||||
def model_selector(config):
    """Instantiate the model named in config['basic']['model'].

    Returns the constructed model, or None for an unrecognised name.
    """
    model_name = config['basic']['model']
    model = None
    if model_name == 'STDEN':
        model = STDENModel(config)
    elif model_name == 'STGODE':
        model = ODEGCN(config)
    return model
||||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -0,0 +1,576 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
    def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scalers, logger, lr_scheduler=None):
        """Wire up model, loaders, loss/optimizer and logging paths.

        Args:
            config: nested dict; self.args becomes config['train'], with the
                device copied in from config['basic'].
            scalers: list of per-feature scalers used for de-normalisation.
            logger: wrapper exposing .logger (a logging.Logger) and .dir_path.
        """
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scalers = scalers  # now a list of scalers, one per feature
        self.args = config['train']
        self.logger = logger
        self.args['device'] = config['basic']['device']
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0
        # checkpoint / figure output locations inside the run directory
        self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
        self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
        self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')

        # accumulators for NFE (number of function evaluations) statistics
        self.c = []
        self.res, self.keys = [], []
|
    def _run_epoch(self, epoch, dataloader, mode):
        """Run one pass over *dataloader* in 'train', 'val' or 'test' mode.

        Gradients/optimizer steps only happen in train mode; per-batch NFE
        statistics are collected in train mode. Returns the average loss.
        The model is expected to return (output, fe) where fe is a
        (count, seconds) pair — presumably ODE function-evaluation stats;
        TODO confirm against the model implementations.
        """
        if mode == 'train':
            self.model.train()
            optimizer_step = True
        else:
            self.model.eval()
            optimizer_step = False

        total_loss = 0
        epoch_time = time.time()

        # reset the per-epoch NFE accumulator
        if mode == 'train':
            self.c.clear()

        with torch.set_grad_enabled(optimizer_step):
            with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                for batch_idx, (data, target) in enumerate(dataloader):
                    # supervise only the first output_dim feature channels
                    label = target[..., :self.args['output_dim']]
                    output, fe = self.model(data)

                    if self.args['real_value']:
                        # de-normalise only the output dimensions
                        output = self._inverse_transform_output(output)

                    loss = self.loss(output, label)

                    # collect NFE stats (train mode only)
                    if mode == 'train':
                        self.c.append([*fe, loss.item()])
                        self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))

                    if optimizer_step and self.optimizer is not None:
                        self.optimizer.zero_grad()
                        loss.backward()

                        if self.args['grad_norm']:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                        self.optimizer.step()

                    total_loss += loss.item()

                    if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
                        # NOTE(review): elsewhere this class logs via
                        # self.logger.logger.info — confirm the wrapper also
                        # exposes .info directly.
                        self.logger.info(
                            f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')

                    # advance the progress bar
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())

        avg_loss = total_loss / len(dataloader)
        self.logger.logger.info(
            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')

        # archive this epoch's NFE stats (train mode only)
        if mode == 'train':
            self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
            self.keys.append(epoch)

        return avg_loss
|
def _inverse_transform_output(self, output):
|
||||||
|
"""
|
||||||
|
只对输出维度进行反归一化
|
||||||
|
假设输出数据形状为 [batch, horizon, nodes, features]
|
||||||
|
只对前output_dim个特征进行反归一化
|
||||||
|
"""
|
||||||
|
if not self.args['real_value']:
|
||||||
|
return output
|
||||||
|
|
||||||
|
# 获取输出维度的数量
|
||||||
|
output_dim = self.args['output_dim']
|
||||||
|
|
||||||
|
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||||
|
if output_dim <= len(self.scalers):
|
||||||
|
# 对每个输出特征分别进行反归一化
|
||||||
|
for feature_idx in range(output_dim):
|
||||||
|
if feature_idx < len(self.scalers):
|
||||||
|
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||||
|
for feature_idx in range(len(self.scalers)):
|
||||||
|
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
    def train_epoch(self, epoch):
        """Run one training pass over the training loader; returns avg loss."""
        return self._run_epoch(epoch, self.train_loader, 'train')
|
    def val_epoch(self, epoch):
        """Run one validation pass; falls back to the test loader when no
        validation loader was provided. Returns avg loss."""
        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
|
    def test_epoch(self, epoch):
        """Run one evaluation pass over the test loader; returns avg loss."""
        return self._run_epoch(epoch, self.test_loader, 'test')
|
def train(self):
|
||||||
|
best_model, best_test_model = None, None
|
||||||
|
best_loss, best_test_loss = float('inf'), float('inf')
|
||||||
|
not_improved_count = 0
|
||||||
|
|
||||||
|
self.logger.logger.info("Training process started")
|
||||||
|
for epoch in range(1, self.args['epochs'] + 1):
|
||||||
|
train_epoch_loss = self.train_epoch(epoch)
|
||||||
|
val_epoch_loss = self.val_epoch(epoch)
|
||||||
|
test_epoch_loss = self.test_epoch(epoch)
|
||||||
|
|
||||||
|
if train_epoch_loss > 1e6:
|
||||||
|
self.logger.logger.warning('Gradient explosion detected. Ending...')
|
||||||
|
break
|
||||||
|
|
||||||
|
if val_epoch_loss < best_loss:
|
||||||
|
best_loss = val_epoch_loss
|
||||||
|
not_improved_count = 0
|
||||||
|
best_model = copy.deepcopy(self.model.state_dict())
|
||||||
|
torch.save(best_model, self.best_path)
|
||||||
|
self.logger.logger.info('Best validation model saved!')
|
||||||
|
else:
|
||||||
|
not_improved_count += 1
|
||||||
|
|
||||||
|
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
|
||||||
|
self.logger.logger.info(
|
||||||
|
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
|
||||||
|
break
|
||||||
|
|
||||||
|
if test_epoch_loss < best_test_loss:
|
||||||
|
best_test_loss = test_epoch_loss
|
||||||
|
best_test_model = copy.deepcopy(self.model.state_dict())
|
||||||
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
||||||
|
self._save_nfe_data()
|
||||||
|
|
||||||
|
if not self.args['debug']:
|
||||||
|
torch.save(best_model, self.best_path)
|
||||||
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||||
|
|
||||||
|
self._finalize_training(best_model, best_test_model)
|
||||||
|
|
||||||
|
    def _finalize_training(self, best_model, best_test_model):
        """Evaluate both tracked checkpoints on the test set.

        The best-validation model is evaluated without visualisations; the
        best-test model is evaluated with them.
        """
        self.model.load_state_dict(best_model)
        self.logger.logger.info("Testing on best validation model")
        self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)

        self.model.load_state_dict(best_test_model)
        self.logger.logger.info("Testing on best test model")
        self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
|
@staticmethod
|
||||||
|
def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
|
||||||
|
if path:
|
||||||
|
checkpoint = torch.load(path)
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
model.to(args['device'])
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
y_pred, y_true = [], []
|
||||||
|
|
||||||
|
# 用于收集nfe数据
|
||||||
|
c = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, target in data_loader:
|
||||||
|
label = target[..., :args['output_dim']]
|
||||||
|
output, fe = model(data)
|
||||||
|
y_pred.append(output)
|
||||||
|
y_true.append(label)
|
||||||
|
|
||||||
|
# 收集nfe数据
|
||||||
|
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
||||||
|
|
||||||
|
if args['real_value']:
|
||||||
|
# 只对输出维度进行反归一化
|
||||||
|
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
||||||
|
else:
|
||||||
|
y_pred = torch.cat(y_pred, dim=0)
|
||||||
|
y_true = torch.cat(y_true, dim=0)
|
||||||
|
|
||||||
|
# 计算每个时间步的指标
|
||||||
|
for t in range(y_true.shape[1]):
|
||||||
|
mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
||||||
|
args['mae_thresh'], args['mape_thresh'])
|
||||||
|
logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
|
|
||||||
|
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
||||||
|
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
||||||
|
Trainer._save_nfe_data_static(c, model, logger)
|
||||||
|
|
||||||
|
# 只在需要时生成可视化图片
|
||||||
|
if generate_viz:
|
||||||
|
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
||||||
|
Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
|
||||||
|
Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
||||||
|
target_node=1, num_samples=10, scalers=scalers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _inverse_transform_output_static(output, args, scalers):
|
||||||
|
"""
|
||||||
|
静态方法:只对输出维度进行反归一化
|
||||||
|
"""
|
||||||
|
if not args['real_value']:
|
||||||
|
return output
|
||||||
|
|
||||||
|
# 获取输出维度的数量
|
||||||
|
output_dim = args['output_dim']
|
||||||
|
|
||||||
|
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||||
|
if output_dim <= len(scalers):
|
||||||
|
# 对每个输出特征分别进行反归一化
|
||||||
|
for feature_idx in range(output_dim):
|
||||||
|
if feature_idx < len(scalers):
|
||||||
|
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||||
|
for feature_idx in range(len(scalers)):
|
||||||
|
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
    @staticmethod
    def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
        """
        Generate per-node prediction figures.

        Writes one detailed true-vs-predicted figure per node (first 10
        nodes, first 500 time steps) plus a 2x5 summary grid (first 100
        steps) under <save_dir>/pic/.

        Args:
            y_pred: predictions
            y_true: ground truth
            logger: logger wrapper (unused here beyond the signature)
            save_dir: output directory root
        """
        import matplotlib.pyplot as plt
        import numpy as np
        import os
        import matplotlib
        from tqdm import tqdm

        # quieten matplotlib's font-lookup chatter
        matplotlib.set_loglevel('error')  # errors only
        plt.rcParams['font.family'] = 'DejaVu Sans'  # default font

        # validity check
        if y_pred is None or y_true is None:
            return

        # create the pic folder
        pic_dir = os.path.join(save_dir, 'pic')
        os.makedirs(pic_dir, exist_ok=True)

        # always draw 10 figures
        num_nodes_to_plot = 10

        # detailed per-node figures
        with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
            for node_id in range(num_nodes_to_plot):
                # fetch this node's series
                if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
                    # data layout: [time_step, seq_len, num_node, dim]
                    # NOTE(review): horizon index 12 raises IndexError when
                    # seq_len == 12 (valid indices 0..11), and the title below
                    # talks about t=1 — confirm the intended horizon index.
                    node_pred = y_pred[:, 12, node_id, 0].cpu().numpy()  # chosen horizon, this node, first feature
                    node_true = y_true[:, 12, node_id, 0].cpu().numpy()
                else:
                    # fewer than 10 nodes: only handle nodes that exist
                    if node_id >= y_pred.shape[-2]:
                        pbar.update(1)
                        continue
                    else:
                        node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
                        node_true = y_true[:, 0, node_id, 0].cpu().numpy()

                # validity check
                if np.isnan(node_pred).any() or np.isnan(node_true).any():
                    pbar.update(1)
                    continue

                # keep the first 500 time steps
                max_steps = min(500, len(node_pred))
                if max_steps <= 0:
                    pbar.update(1)
                    continue

                node_pred_500 = node_pred[:max_steps]
                node_true_500 = node_true[:max_steps]

                # time axis
                time_steps = np.arange(max_steps)

                # plot the comparison
                plt.figure(figsize=(12, 6))
                plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
                plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
                plt.xlabel('Time Steps')
                plt.ylabel('Values')
                plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
                plt.legend()
                plt.grid(True, alpha=0.3)

                # save with a per-node filename
                save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                plt.close()

                pbar.update(1)

        # summary grid across all plotted nodes (first 100 steps for clarity)
        plot_steps = min(100, y_pred.shape[0])
        if plot_steps <= 0:
            return

        # create the subplot grid
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.flatten()

        for node_id in range(num_nodes_to_plot):
            if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
                # data layout: [time_step, seq_len, num_node, dim]
                node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
                node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
            else:
                # fewer than 10 nodes: only handle nodes that exist
                if node_id >= y_pred.shape[-2]:
                    axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                                       ha='center', va='center', transform=axes[node_id].transAxes)
                    continue
                else:
                    node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
                    node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()

            # validity check
            if np.isnan(node_pred).any() or np.isnan(node_true).any():
                axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                                   ha='center', va='center', transform=axes[node_id].transAxes)
                continue

            time_steps = np.arange(plot_steps)

            axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
            axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
            axes[node_id].set_title(f'Node {node_id + 1}')
            axes[node_id].grid(True, alpha=0.3)
            axes[node_id].legend(fontsize=8)

            if node_id >= 5:  # bottom row gets the x-axis label
                axes[node_id].set_xlabel('Time Steps')
            if node_id % 5 == 0:  # left column gets the y-axis label
                axes[node_id].set_ylabel('Values')

        plt.tight_layout()
        summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')
        plt.close()
||||||
|
    @staticmethod
    def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
                                          target_node=1, num_samples=10, scalers=None):
        """Generate per-sample input-vs-output comparison plots for one node.

        For each of the first `num_samples` batches drawn from `data_loader`,
        plots the (de-normalized) input sequence followed by the true and
        predicted output sequences for `target_node`, saving one PNG per
        sample plus a 2x5 summary grid under `<save_dir>/pic/compare/`.

        Args:
            y_pred: predicted values, indexed as [sample, horizon, node,
                feature]; moved with `.cpu()`, so a torch tensor is expected.
            y_true: ground-truth values, same indexing as `y_pred`.
            data_loader: iterable yielding `(data, target)` batches; `data`
                is indexed as [batch, seq_len, node, feature].
            logger: log handle (accepted but not used in this method).
            save_dir: root directory; plots go to `<save_dir>/pic/compare/`.
            target_node: 1-based id of the node to plot.
            num_samples: number of samples/batches to compare.
            scalers: optional list of scalers; `scalers[0]` is used to
                inverse-transform the input sequence.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        import os
        import matplotlib
        from tqdm import tqdm

        # Matplotlib setup: silence non-error log output, fixed font.
        matplotlib.set_loglevel('error')
        plt.rcParams['font.family'] = 'DejaVu Sans'

        # Create the compare/ output folder.
        compare_dir = os.path.join(save_dir, 'pic', 'compare')
        os.makedirs(compare_dir, exist_ok=True)

        # Collect the first `num_samples` input batches as numpy arrays.
        input_data = []
        for batch_idx, (data, target) in enumerate(data_loader):
            if batch_idx >= num_samples:
                break
            input_data.append(data.cpu().numpy())

        if not input_data:
            return

        # 0-based index of the target node.
        node_idx = target_node - 1

        # Bail out if the requested node does not exist in the predictions.
        if node_idx >= y_pred.shape[-2]:
            return

        # One combined figure per sample.
        with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
            for sample_idx in range(min(num_samples, len(input_data))):
                # Input sequence: first batch element, all timesteps, target node, feature 0.
                # NOTE(review): only element 0 of each collected batch is used
                # here, while y_pred/y_true are indexed by sample_idx on their
                # first axis — confirm both indexings refer to the same sample.
                input_seq = input_data[sample_idx][0, :, node_idx, 0]

                # De-normalize the input with the first scaler, if provided.
                # (Assumes input features were normalized with scalers[0] — TODO confirm.)
                if scalers is not None and len(scalers) > 0:
                    input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()

                # Matching prediction and ground-truth sequences over all horizons.
                pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
                true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()

                # Skip samples containing NaNs.
                if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
                    pbar.update(1)
                    continue

                # Continuous time axis: input segment followed by output segment.
                total_time = np.arange(len(input_seq) + len(pred_seq))

                # Single figure showing input and output together.
                plt.figure(figsize=(14, 8))

                # Full true curve: input followed by the true output.
                true_combined = np.concatenate([input_seq, true_seq])
                plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
                         linewidth=2.5, alpha=0.9, linestyle='-')

                # Predicted curve: drawn only over the output segment.
                output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
                plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
                         linewidth=2, alpha=0.8, linestyle='-')

                # Vertical line marking the input/output boundary.
                plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
                            label='Input/Output Boundary')

                # Figure decorations.
                plt.xlabel('Time Steps')
                plt.ylabel('Values')
                plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
                plt.legend()
                plt.grid(True, alpha=0.3)

                # Tidy the layout before saving.
                plt.tight_layout()

                # Save and close to free figure memory.
                save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                plt.close()

                pbar.update(1)

        # Summary grid: all samples compared side by side, 2x5 layout.
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.flatten()

        for sample_idx in range(min(num_samples, len(input_data))):
            if sample_idx >= 10:  # at most 10 subplots available
                break

            ax = axes[sample_idx]

            # Same per-sample extraction/de-normalization as the loop above.
            input_seq = input_data[sample_idx][0, :, node_idx, 0]
            if scalers is not None and len(scalers) > 0:
                input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()

            pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
            true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()

            # Placeholder text for samples with NaNs.
            if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
                ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
                        ha='center', va='center', transform=ax.transAxes)
                continue

            # Input and output drawn continuously on one axis.
            total_time = np.arange(len(input_seq) + len(pred_seq))
            true_combined = np.concatenate([input_seq, true_seq])
            output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))

            ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
            ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
            ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
            ax.set_title(f'Sample {sample_idx + 1}')
            ax.grid(True, alpha=0.3)
            ax.legend(fontsize=8)

            if sample_idx >= 5:  # bottom row gets x-axis labels
                ax.set_xlabel('Time Steps')
            if sample_idx % 5 == 0:  # left column gets y-axis labels
                ax.set_ylabel('Values')

        # Hide any unused subplots.
        for i in range(min(num_samples, len(input_data)), 10):
            axes[i].set_visible(False)

        plt.tight_layout()
        summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')
        plt.close()
|
||||||
|
|
||||||
|
def _save_nfe_data(self):
|
||||||
|
"""保存nfe数据到文件"""
|
||||||
|
if not self.res:
|
||||||
|
return
|
||||||
|
|
||||||
|
res = pd.concat(self.res, keys=self.keys)
|
||||||
|
res.index.names = ['epoch', 'iter']
|
||||||
|
|
||||||
|
# 获取模型配置参数
|
||||||
|
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
||||||
|
atol = getattr(self.model, 'atol', 1e-5)
|
||||||
|
rtol = getattr(self.model, 'rtol', 1e-5)
|
||||||
|
|
||||||
|
# 保存nfe数据
|
||||||
|
nfe_file = os.path.join(
|
||||||
|
self.logger.dir_path,
|
||||||
|
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||||
|
res.to_pickle(nfe_file)
|
||||||
|
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_sampling_threshold(global_step, k):
|
||||||
|
return k / (k + math.exp(global_step / k))
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
from trainer.trainer import Trainer
|
from trainer.trainer import Trainer
|
||||||
|
from trainer.ode_trainer import Trainer as ode_trainer
|
||||||
|
|
||||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||||
lr_scheduler, kwargs):
|
lr_scheduler, kwargs):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
selected_Trainer = None
|
selected_Trainer = None
|
||||||
match model_name:
|
match model_name:
|
||||||
|
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
||||||
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||||
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
if selected_Trainer is None: raise NotImplementedError
|
if selected_Trainer is None: raise NotImplementedError
|
||||||
return selected_Trainer
|
return selected_Trainer
|
||||||
Loading…
Reference in New Issue