Compare commits


11 Commits
STGODE ... main

22 changed files with 884 additions and 295 deletions

8
.idea/.gitignore vendored
View File

@@ -1,8 +0,0 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/Project-I.iml vendored
View File

@@ -1,12 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

27
.idea/inspectionProfiles/Project_Default.xml vendored
View File

@@ -1,27 +0,0 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="136" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="8">
<item index="0" class="java.lang.String" itemvalue="argparse" />
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
<item index="5" class="java.lang.String" itemvalue="setuptools" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

6
.idea/inspectionProfiles/profiles_settings.xml vendored
View File

@@ -1,6 +0,0 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml vendored
View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml vendored
View File

@@ -1,8 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
</modules>
</component>
</project>

7
.idea/vcs.xml vendored
View File

@@ -1,7 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

216
.idea/workspace.xml vendored
View File

@@ -1,216 +0,0 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AutoImportSettings">
<option name="autoReloadType" value="SELECTIVE" />
</component>
<component name="ChangeListManager">
    <list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="">
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
</list>
<option name="SHOW_DIALOG" value="false" />
<option name="HIGHLIGHT_CONFLICTS" value="true" />
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
<option name="LAST_RESOLUTION" value="IGNORE" />
</component>
<component name="FileTemplateManagerImpl">
<option name="RECENT_TEMPLATES">
<list>
<option value="Python Script" />
</list>
</option>
</component>
<component name="Git.Settings">
<excluded-from-favorite>
<branch-storage>
<map>
<entry type="LOCAL">
<value>
<list>
<branch-info repo="$PROJECT_DIR$" source="main" />
</list>
</value>
</entry>
</map>
</branch-storage>
</excluded-from-favorite>
<option name="RECENT_BRANCH_BY_REPOSITORY">
<map>
<entry key="$PROJECT_DIR$" value="main" />
</map>
</option>
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
<option name="ROOT_SYNC" value="DONT_SYNC" />
</component>
<component name="MarkdownSettingsMigration">
<option name="stateVersion" value="1" />
</component>
<component name="ProjectColorInfo">{
&quot;associatedIndex&quot;: 3
}</component>
<component name="ProjectId" id="3264JlB7seHXuXCCcdmTyEsXI45" />
<component name="ProjectViewState">
<option name="hideEmptyMiddlePackages" value="true" />
<option name="showLibraryContents" value="true" />
</component>
<component name="PropertiesComponent"><![CDATA[{
"keyToString": {
"Python.STDEN.executor": "Debug",
"Python.STGODE.executor": "Run",
"Python.main.executor": "Run",
"RunOnceActivity.OpenProjectViewOnStart": "true",
"RunOnceActivity.ShowReadmeOnStart": "true",
"git-widget-placeholder": "STGODE",
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
"node.js.detected.package.eslint": "true",
"node.js.detected.package.tslint": "true",
"node.js.selected.package.eslint": "(autodetect)",
"node.js.selected.package.tslint": "(autodetect)",
"nodejs_package_manager_path": "npm",
"vue.rearranger.settings.migration": "true"
}
}]]></component>
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
<layout>
<window_info id="Space Code Reviews" show_stripe_button="false" />
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
<window_info id="Merge Requests" show_stripe_button="false" />
<window_info id="Commit_Guest" show_stripe_button="false" />
<window_info id="Pull Requests" show_stripe_button="false" />
<window_info id="Learn" show_stripe_button="false" />
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
<window_info id="Commit" order="1" weight="0.25" />
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
<window_info anchor="bottom" id="Version Control" order="0" />
<window_info anchor="bottom" id="Problems" order="1" />
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
<window_info anchor="bottom" id="Services" order="4" />
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
<window_info anchor="right" id="SciView" show_stripe_button="false" />
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
<window_info anchor="right" id="Database" order="2" weight="0.25" />
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
<window_info anchor="right" id="Plots" order="5" weight="0.1" />
</layout>
</component>
<component name="RecentsManager">
<key name="CopyFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/trainer" />
<recent name="$PROJECT_DIR$/configs/STDEN" />
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
<key name="MoveFile.RECENT_KEYS">
<recent name="$PROJECT_DIR$/models/STDEN" />
</key>
</component>
<component name="RunManager" selected="Python.STGODE">
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STDEN/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
<module name="Project-I" />
<option name="ENV_FILES" value="" />
<option name="INTERPRETER_OPTIONS" value="" />
<option name="PARENT_ENVS" value="true" />
<envs>
<env name="PYTHONUNBUFFERED" value="1" />
</envs>
<option name="SDK_HOME" value="" />
<option name="SDK_NAME" value="TS" />
<option name="WORKING_DIRECTORY" value="" />
<option name="IS_MODULE_SDK" value="false" />
<option name="ADD_CONTENT_ROOTS" value="true" />
<option name="ADD_SOURCE_ROOTS" value="true" />
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
<option name="SHOW_COMMAND_LINE" value="false" />
<option name="EMULATE_TERMINAL" value="false" />
<option name="MODULE_MODE" value="false" />
<option name="REDIRECT_INPUT" value="false" />
<option name="INPUT_FILE" value="" />
<method v="2" />
</configuration>
<list>
<item itemvalue="Python.STDEN" />
<item itemvalue="Python.STGODE" />
</list>
</component>
<component name="SharedIndexes">
<attachedChunks>
<set>
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
</set>
</attachedChunks>
</component>
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
<component name="TaskManager">
    <task active="true" id="Default" summary="Default task">
      <changelist id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="Changes" comment="" />
<created>1756727620810</created>
<option name="number" value="Default" />
<option name="presentableId" value="Default" />
<updated>1756727620810</updated>
<workItem from="1756727623101" duration="4721000" />
<workItem from="1756856673845" duration="652000" />
<workItem from="1756864144998" duration="1063000" />
</task>
<servers />
</component>
<component name="TypeScriptGeneratedFilesManager">
<option name="version" value="3" />
</component>
<component name="XDebuggerManager">
<breakpoint-manager>
<breakpoints>
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
<url>file://$PROJECT_DIR$/models/STDEN/stden_model.py</url>
<line>131</line>
<option name="timeStamp" value="5" />
</line-breakpoint>
</breakpoints>
</breakpoint-manager>
</component>
<component name="com.intellij.coverage.CoverageDataManagerImpl">
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
</component>
</project>

README.md
View File

@@ -1,3 +1,35 @@
# Project-I
Secret Project
## Prepare Env.
```
pip install -r requirements.txt
```
## Download dataset
```
python utils/download.py
```
## Download gpt weight
`mkdir -p models/gpt2`
Download config.json & pytorch_model.bin from https://huggingface.co/openai-community/gpt2/tree/main
```bash
wget "https://huggingface.co/openai-community/gpt2/resolve/main/config.json?download=true" -O ./models/gpt2/config.json
wget "https://huggingface.co/openai-community/gpt2/resolve/main/pytorch_model.bin?download=true" -O ./models/gpt2/pytorch_model.bin
```
Use PyTorch >= 2.6 to load the model.
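Optional sanity check that the weights load from the local folder (a minimal sketch, assuming `transformers` is installed):
```python
# Hypothetical check: load the GPT-2 weights downloaded above.
from transformers import GPT2Model

model = GPT2Model.from_pretrained("./models/gpt2", local_files_only=True)
print(model.config.hidden_size)  # 768 for the base gpt2 checkpoint
```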
## Run
```
python main.py --config configs/STGODE_LLM_GPT2/PEMS08.yaml
```

1
STDEN

@@ -1 +0,0 @@
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2

65
configs/STGODE_LLM/PEMS08.yaml Normal file
View File

@@ -0,0 +1,65 @@
basic:
device: cuda:0
dataset: PEMS08
model: STGODE-LLM
mode: test
seed: 2025
data:
dataset_dir: data/PEMS08
val_batch_size: 32
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
num_nodes: 170
batch_size: 64
input_dim: 1
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 24
days_per_week: 7
model:
input_dim: 1
output_dim: 1
history: 12
horizon: 12
num_features: 1
rnn_units: 64
sigma1: 0.1
sigma2: 10
thres1: 0.6
thres2: 0.5
# LLM backbone settings
llm_hidden: 128
llm_layers: 4
llm_heads: 4
llm_pretrained: True
train:
loss: mae
batch_size: 64
epochs: 100
lr_init: 0.003
mape_thresh: 0.001
mae_thresh: None
debug: False
output_dim: 1
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
log_step: 3000

66
configs/STGODE_LLM_GPT2/PEMS08.yaml Normal file
View File

@@ -0,0 +1,66 @@
basic:
device: cuda:0
dataset: PEMS08
model: STGODE-LLM-GPT2
mode: train
seed: 2025
data:
dataset_dir: data/PEMS08
val_batch_size: 16
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
num_nodes: 170
batch_size: 32
input_dim: 1
lag: 12
horizon: 12
val_ratio: 0.2
test_ratio: 0.2
tod: False
normalizer: std
column_wise: False
default_graph: True
add_time_in_day: True
add_day_in_week: True
steps_per_day: 24
days_per_week: 7
model:
input_dim: 1
output_dim: 1
history: 12
horizon: 12
num_features: 1
rnn_units: 64
sigma1: 0.1
sigma2: 10
thres1: 0.6
thres2: 0.5
# HF GPT-2 settings
gpt2_name: gpt2
gpt2_grad_ckpt: True
gpt2_freeze: True
gpt2_local_dir: ./models/gpt2
train:
loss: mae
batch_size: 32
epochs: 100
lr_init: 0.0003
mape_thresh: 0.001
mae_thresh: None
debug: False
output_dim: 1
weight_decay: 0
lr_decay: False
lr_decay_rate: 0.3
lr_decay_step: "10,30,60,90"
early_stop: True
early_stop_patience: 15
grad_norm: False
max_grad_norm: 5
real_value: True
log_step: 3000
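With `gpt2_freeze: True`, the GPT-2 weights are frozen and only the surrounding projection layers train; a quick hedged sketch for verifying that on any module:
```python
# Hypothetical helper: compare trainable vs. total parameter counts.
import torch.nn as nn

def param_counts(module: nn.Module):
    total = sum(p.numel() for p in module.parameters())
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    return trainable, total

# With GPT2BackboneHF built under freeze=True, trainable should be far below total.
```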

220
data/get_adj.py Normal file
View File

@@ -0,0 +1,220 @@
import csv
import os
import numpy as np
import pandas as pd
from scipy.sparse import coo_matrix
from scipy.sparse.linalg import norm
import scipy.sparse as sp
import torch
def get_adj(args):
dataset_path = './data'
match args['num_nodes']:
case 358:
dataset_name = 'PEMS03'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
case 307:
dataset_name = 'PEMS04'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
case 883:
dataset_name = 'PEMS07'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
A, adj = load_adj(args['num_nodes'], adj_path)
case 170:
dataset_name = 'PEMS08'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case _:
            raise ValueError(f"Unsupported num_nodes: {args['num_nodes']}")
    return adj
def get_gso(args):
dataset_path = './data'
match args['num_nodes']:
case 358:
dataset_name = 'PEMS03'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS03.csv')
id = os.path.join(dataset_path, dataset_name, 'PEMS03.txt')
A, adj = load_adj(args['num_nodes'], adj_path, id_filename=id)
case 307:
dataset_name = 'PEMS04'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS04.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
case 883:
dataset_name = 'PEMS07'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS07.csv')
A, adj = load_adj(args['num_nodes'], adj_path)
case 170:
dataset_name = 'PEMS08'
adj_path = os.path.join(dataset_path, dataset_name, 'PEMS08.csv')
A, adj = load_adj(args['num_nodes'], adj_path, std=True)
        case _:
            raise ValueError(f"Unsupported num_nodes: {args['num_nodes']}")
    gso = calc_gso(adj, args['gso_type'])
if args['graph_conv_type'] == 'cheb_graph_conv':
gso = calc_chebynet_gso(gso)
gso = gso.toarray()
gso = gso.astype(dtype=np.float32)
gso = torch.from_numpy(gso).to(args['device'])
return gso
def load_adj(num_nodes, adj_path, id_filename=None, std=False):
'''
Parameters
----------
adj_path: str, path of the csv file contains edges information
num_nodes: int, the number of vertices
id_filename: str, optional, path of the file containing node IDs (if not starting from 0)
std: bool, if True, normalize the cost values in the CSV file using Gaussian normalization
Returns
----------
A: np.ndarray, adjacency matrix
distanceA: np.ndarray, distance matrix (normalized if std=True)
'''
if 'npy' in adj_path:
adj_mx = np.load(adj_path)
return adj_mx, None
else:
A = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
distanceA = np.zeros((int(num_nodes), int(num_nodes)), dtype=np.float32)
        # If id_filename is provided, node IDs do not start from 0 and must be remapped
        if id_filename:
            with open(id_filename, 'r') as f:
                id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))}
            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collect all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[id_dict[i], id_dict[j]] = 1
                    # make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)  # collect the cost value
                    distanceA[id_dict[i], id_dict[j]] = distance
        else:  # without id_filename, node IDs start from 0
            with open(adj_path, 'r') as f:
                f.readline()  # skip the header row
                reader = csv.reader(f)
                costs = []  # collect all cost values
                for row in reader:
                    if len(row) != 3:
                        continue
                    i, j, distance = int(row[0]), int(row[1]), float(row[2])
                    A[i, j] = 1
                    # make sure the distance is positive
                    distance = max(distance, 1e-6)
                    costs.append(distance)  # collect the cost value
                    distanceA[i, j] = distance
        # With std=True, apply Gaussian (z-score) normalization to all cost values from the CSV
        if std:
            mean_cost = np.mean(costs)  # mean of the cost values
            std_cost = np.std(costs)  # standard deviation of the cost values
            for idx in np.ndindex(distanceA.shape):  # iterate over the matrix
                if distanceA[idx] > 0:  # normalize non-zero entries only
                    normalized_value = (distanceA[idx] - mean_cost) / std_cost
                    # keep the normalized value positive
                    normalized_value = max(normalized_value, 1e-6)
                    distanceA[idx] = normalized_value
        # make sure the matrix has no all-zero rows
        row_sums = distanceA.sum(axis=1)
        zero_rows = np.where(row_sums == 0)[0]
        for row in zero_rows:
            distanceA[row, :] = 1e-6  # replace zero rows with a small non-zero default
        return A, distanceA
def calc_gso(dir_adj, gso_type):
n_vertex = dir_adj.shape[0]
if not sp.issparse(dir_adj):
dir_adj = sp.csc_matrix(dir_adj)
elif dir_adj.format != 'csc':
dir_adj = dir_adj.tocsc()
id = sp.identity(n_vertex, format='csc')
# Symmetrizing an adjacency matrix
adj = dir_adj + dir_adj.T.multiply(dir_adj.T > dir_adj) - dir_adj.multiply(dir_adj.T > dir_adj)
# adj = 0.5 * (dir_adj + dir_adj.transpose())
if gso_type in ['sym_renorm_adj', 'rw_renorm_adj', 'sym_renorm_lap', 'rw_renorm_lap']:
adj = adj + id
if gso_type in ['sym_norm_adj', 'sym_renorm_adj', 'sym_norm_lap', 'sym_renorm_lap']:
row_sum = adj.sum(axis=1).A1
# Check for zero or negative values in row_sum
if np.any(row_sum <= 0):
raise ValueError(
"Row sum contains zero or negative values, which is not allowed for symmetric normalization.")
row_sum_inv_sqrt = np.power(row_sum, -0.5)
row_sum_inv_sqrt[np.isinf(row_sum_inv_sqrt)] = 0. # Handle inf values
deg_inv_sqrt = sp.diags(row_sum_inv_sqrt, format='csc')
# A_{sym} = D^{-0.5} * A * D^{-0.5}
sym_norm_adj = deg_inv_sqrt.dot(adj).dot(deg_inv_sqrt)
if gso_type in ['sym_norm_lap', 'sym_renorm_lap']:
sym_norm_lap = id - sym_norm_adj
gso = sym_norm_lap
else:
gso = sym_norm_adj
elif gso_type in ['rw_norm_adj', 'rw_renorm_adj', 'rw_norm_lap', 'rw_renorm_lap']:
row_sum = np.sum(adj, axis=1).A1
# Check for zero or negative values in row_sum
if np.any(row_sum <= 0):
raise ValueError(
"Row sum contains zero or negative values, which is not allowed for random walk normalization.")
row_sum_inv = np.power(row_sum, -1)
row_sum_inv[np.isinf(row_sum_inv)] = 0. # Handle inf values
deg_inv = sp.diags(row_sum_inv, format='csc')
# A_{rw} = D^{-1} * A
rw_norm_adj = deg_inv.dot(adj)
if gso_type in ['rw_norm_lap', 'rw_renorm_lap']:
rw_norm_lap = id - rw_norm_adj
gso = rw_norm_lap
else:
gso = rw_norm_adj
else:
raise ValueError(f'{gso_type} is not defined.')
# Check for nan or inf in the final result
if np.isnan(gso.data).any() or np.isinf(gso.data).any():
raise ValueError("NaN or Inf detected in the final GSO matrix. Please check the input adjacency matrix.")
return gso
def calc_chebynet_gso(gso):
    if not sp.issparse(gso):
gso = sp.csc_matrix(gso)
elif gso.format != 'csc':
gso = gso.tocsc()
id = sp.identity(gso.shape[0], format='csc')
# If you encounter a NotImplementedError, please update your scipy version to 1.10.1 or later.
eigval_max = norm(gso, 2)
# If the gso is symmetric or random walk normalized Laplacian,
# then the maximum eigenvalue is smaller than or equals to 2.
if eigval_max >= 2:
gso = gso - id
else:
gso = 2 * gso / eigval_max - id
return gso
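A hypothetical smoke test for the helpers above, using a 3-node path graph instead of the PEMS files (assumes they are importable from data/get_adj.py and scipy >= 1.10.1):
```python
# Toy check of the GSO pipeline on a tiny symmetric adjacency matrix.
import numpy as np
from data.get_adj import calc_gso, calc_chebynet_gso

adj = np.array([[0, 1, 0],
                [1, 0, 1],
                [0, 1, 0]], dtype=np.float32)

gso = calc_gso(adj, 'sym_norm_lap')  # symmetric normalized Laplacian (sparse CSC)
cheb = calc_chebynet_gso(gso)        # rescaled for Chebyshev graph convolution
print(cheb.toarray().round(3))
```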

152
models/STGODE_LLM/STGODE_LLM.py Normal file
View File

@@ -0,0 +1,152 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i - 1]
out_channels = num_channels[i]
padding = (kernel_size - 1) * dilation_size
self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
padding=(0, padding))
self.conv.weight.data.normal_(0, 0.01)
self.chomp = Chomp1d(padding)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
self.network = nn.Sequential(*layers)
self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
if self.downsample:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
y = x.permute(0, 3, 1, 2)
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
y = y.permute(0, 2, 3, 1)
return y
class STGCNBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_nodes, A_hat):
super(STGCNBlock, self).__init__()
self.A_hat = A_hat
self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
self.batch_norm = nn.BatchNorm2d(num_nodes)
def forward(self, X):
t = self.temporal1(X)
t = self.odeg(t)
t = self.temporal2(F.relu(t))
return self.batch_norm(t)
class GPT2Backbone(nn.Module):
def __init__(self, hidden_size: int, n_layer: int = 4, n_head: int = 4, n_embd: int | None = None, use_pretrained: bool = True):
super().__init__()
self.hidden_size = hidden_size
self.use_transformers = False
self.model = None
if n_embd is None:
n_embd = hidden_size
if use_pretrained:
try:
from transformers import GPT2Model, GPT2Config
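                # NOTE: despite the flag name, GPT2Model(config) below is randomly
                # initialized at the requested size; no pretrained weights are loaded
                # (see GPT2BackboneHF in STGODE_LLM_GPT2.py for actual pretrained GPT-2).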
config = GPT2Config(n_embd=n_embd, n_layer=n_layer, n_head=n_head, n_positions=1024, n_ctx=1024, vocab_size=1)
self.model = GPT2Model(config)
self.use_transformers = True
except Exception:
self.use_transformers = False
if not self.use_transformers:
encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=n_head, batch_first=True)
self.model = nn.TransformerEncoder(encoder_layer, num_layers=n_layer)
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
if self.use_transformers:
outputs = self.model(inputs_embeds=inputs_embeds)
return outputs.last_hidden_state
else:
return self.model(inputs_embeds)
class ODEGCN_LLM(nn.Module):
def __init__(self, config):
super(ODEGCN_LLM, self).__init__()
args = config['model']
num_nodes = config['data']['num_nodes']
num_features = args['num_features']
num_timesteps_input = args['history']
num_timesteps_output = args['horizon']
A_sp_hat, A_se_hat = get_A_hat(config)
self.sp_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
])
self.se_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
])
self.history = num_timesteps_input
self.horizon = num_timesteps_output
hidden_size = int(args.get('llm_hidden', 128))
llm_layers = int(args.get('llm_layers', 4))
llm_heads = int(args.get('llm_heads', 4))
use_pretrained = bool(args.get('llm_pretrained', True))
self.to_llm_embed = nn.Linear(64, hidden_size)
self.gpt2 = GPT2Backbone(hidden_size=hidden_size, n_layer=llm_layers, n_head=llm_heads, use_pretrained=use_pretrained)
self.proj_head = nn.Sequential(
nn.Linear(hidden_size, hidden_size),
nn.ReLU(),
nn.Linear(hidden_size, self.horizon)
)
def forward(self, x):
x = x[..., 0:1].permute(0, 2, 1, 3)
outs = []
for blk in self.sp_blocks:
outs.append(blk(x))
for blk in self.se_blocks:
outs.append(blk(x))
outs = torch.stack(outs)
x = torch.max(outs, dim=0)[0]
# x: (B, N, T, 64) physical quantities after ODE-based transform
B, N, T, C = x.shape
x = self.to_llm_embed(x) # (B, N, T, H)
        x = x.contiguous().view(B * N, T, -1)  # (B*N, T, H); permute(0, 1, 2, 3) was a no-op
llm_hidden = self.gpt2(inputs_embeds=x) # (B*N, T, H)
last_state = llm_hidden[:, -1, :] # (B*N, H)
y = self.proj_head(last_state) # (B*N, horizon)
y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1) # (B, horizon, N, 1)
return y
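A shape sanity check for the backbone in isolation (a hedged sketch; `use_pretrained=False` forces the plain TransformerEncoder fallback, so no transformers install is required):
```python
# Hypothetical shape check for GPT2Backbone's fallback path.
import torch
from models.STGODE_LLM.STGODE_LLM import GPT2Backbone

backbone = GPT2Backbone(hidden_size=128, n_layer=2, n_head=4, use_pretrained=False)
x = torch.randn(4, 12, 128)  # (batch*nodes, history, hidden)
print(backbone(x).shape)     # torch.Size([4, 12, 128])
```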

4
models/STGODE_LLM/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .STGODE_LLM import ODEGCN_LLM

145
models/STGODE_LLM_GPT2/STGODE_LLM_GPT2.py Normal file
View File

@@ -0,0 +1,145 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.STGODE.odegcn import ODEG
from models.STGODE.adj import get_A_hat
class Chomp1d(nn.Module):
def __init__(self, chomp_size):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
def forward(self, x):
return x[:, :, :, :-self.chomp_size].contiguous()
class TemporalConvNet(nn.Module):
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
super(TemporalConvNet, self).__init__()
layers = []
num_levels = len(num_channels)
for i in range(num_levels):
dilation_size = 2 ** i
in_channels = num_inputs if i == 0 else num_channels[i - 1]
out_channels = num_channels[i]
padding = (kernel_size - 1) * dilation_size
self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
padding=(0, padding))
self.conv.weight.data.normal_(0, 0.01)
self.chomp = Chomp1d(padding)
self.relu = nn.ReLU()
self.dropout = nn.Dropout(dropout)
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
self.network = nn.Sequential(*layers)
self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
if self.downsample:
self.downsample.weight.data.normal_(0, 0.01)
def forward(self, x):
y = x.permute(0, 3, 1, 2)
        y = F.relu(self.network(y) + (self.downsample(y) if self.downsample else y))
y = y.permute(0, 2, 3, 1)
return y
class STGCNBlock(nn.Module):
def __init__(self, in_channels, out_channels, num_nodes, A_hat):
super(STGCNBlock, self).__init__()
self.A_hat = A_hat
self.temporal1 = TemporalConvNet(num_inputs=in_channels, num_channels=out_channels)
self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1], num_channels=out_channels)
self.batch_norm = nn.BatchNorm2d(num_nodes)
def forward(self, X):
t = self.temporal1(X)
t = self.odeg(t)
t = self.temporal2(F.relu(t))
return self.batch_norm(t)
class GPT2BackboneHF(nn.Module):
def __init__(self, model_name: str | None = None, gradient_checkpointing: bool = False, freeze: bool = False, local_dir: str | None = None):
super().__init__()
from transformers import GPT2Model
if local_dir is not None and len(local_dir) > 0:
self.model = GPT2Model.from_pretrained(local_dir, local_files_only=True, use_cache=False)
else:
if model_name is None:
model_name = 'gpt2'
self.model = GPT2Model.from_pretrained(model_name, use_cache=False)
if gradient_checkpointing:
self.model.gradient_checkpointing_enable()
self.hidden_size = self.model.config.hidden_size
if freeze:
for p in self.model.parameters():
p.requires_grad = False
def forward(self, inputs_embeds: torch.Tensor) -> torch.Tensor:
outputs = self.model(inputs_embeds=inputs_embeds)
return outputs.last_hidden_state
class ODEGCN_LLM_GPT2(nn.Module):
def __init__(self, config):
super(ODEGCN_LLM_GPT2, self).__init__()
args = config['model']
num_nodes = config['data']['num_nodes']
num_features = args['num_features']
self.history = args['history']
self.horizon = args['horizon']
A_sp_hat, A_se_hat = get_A_hat(config)
self.sp_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
])
self.se_blocks = nn.ModuleList(
[nn.Sequential(
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat),
STGCNBlock(in_channels=64, out_channels=[64, 32, 64], num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
])
# HF GPT-2
gpt2_name = args.get('gpt2_name', 'gpt2')
grad_ckpt = bool(args.get('gpt2_grad_ckpt', False))
gpt2_freeze = bool(args.get('gpt2_freeze', False))
gpt2_local_dir = args.get('gpt2_local_dir', None)
self.gpt2 = GPT2BackboneHF(gpt2_name, gradient_checkpointing=grad_ckpt, freeze=gpt2_freeze, local_dir=gpt2_local_dir)
# Project ODE features to GPT-2 hidden size
self.to_llm_embed = nn.Linear(64, self.gpt2.hidden_size)
# Prediction head
self.proj_head = nn.Sequential(
nn.Linear(self.gpt2.hidden_size, self.gpt2.hidden_size),
nn.ReLU(),
nn.Linear(self.gpt2.hidden_size, self.horizon)
)
def forward(self, x):
x = x[..., 0:1].permute(0, 2, 1, 3)
outs = []
for blk in self.sp_blocks:
outs.append(blk(x))
for blk in self.se_blocks:
outs.append(blk(x))
outs = torch.stack(outs)
x = torch.max(outs, dim=0)[0] # (B, N, T, 64)
B, N, T, C = x.shape
x = self.to_llm_embed(x).view(B * N, T, -1)
llm_hidden = self.gpt2(inputs_embeds=x)
last_state = llm_hidden[:, -1, :]
y = self.proj_head(last_state)
y = y.view(B, N, self.horizon).permute(0, 2, 1).unsqueeze(-1)
return y
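A matching sketch for the HF-backed variant (assumes the weights were fetched into ./models/gpt2 as in the README; GPT-2 base uses hidden size 768):
```python
# Hypothetical check that the frozen HF backbone runs on ODE-style features.
import torch
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import GPT2BackboneHF

backbone = GPT2BackboneHF(local_dir="./models/gpt2", freeze=True)
x = torch.randn(4, 12, backbone.hidden_size)  # (batch*nodes, history, 768)
with torch.no_grad():
    print(backbone(x).shape)  # torch.Size([4, 12, 768])
```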

4
models/STGODE_LLM_GPT2/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from .STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2

models/model_selector.py
View File

@@ -1,5 +1,7 @@
from models.STDEN.stden_model import STDENModel
from models.STGODE.STGODE import ODEGCN
from models.STGODE_LLM_GPT2.STGODE_LLM_GPT2 import ODEGCN_LLM_GPT2

def model_selector(config):
    model_name = config['basic']['model']
@@ -9,4 +11,6 @@ def model_selector(config):
            model = STDENModel(config)
        case 'STGODE':
            model = ODEGCN(config)
        case 'STGODE-LLM-GPT2':
            model = ODEGCN_LLM_GPT2(config)
    return model

Binary file not shown.

43
requirements.txt Normal file
View File

@@ -0,0 +1,43 @@
# Core deep learning frameworks
torch
torchvision
torchaudio
# Scientific computing and data processing
numpy
pandas
scipy
# Machine learning tools
scikit-learn
# Configuration and file handling
pyyaml
# Progress bars
tqdm
# Graph neural networks and distance computation
fastdtw
# Differential equation solver
torchdiffeq
# Natural language processing (for the GPT-2 model)
transformers
# Data visualization
matplotlib
# HTTP requests (for data downloads)
requests
# Archive handling
# Kaggle data download
kagglehub
# Miscellaneous tools
future

146
utils/download.py Normal file
View File

@@ -0,0 +1,146 @@
import os
import requests
import zipfile
import shutil
import kagglehub  # assumes kagglehub is an available library
from tqdm import tqdm

# Expected file-structure information for the integrity check
def check_and_download_data():
    """
    Check the integrity of the data folder and download the appropriate
    data depending on which files are missing.
    """
    current_working_dir = os.getcwd()  # current working directory
    data_dir = os.path.join(current_working_dir, "data")  # assume the data folder sits in the current working directory
expected_structure = {
"PEMS03": ["PEMS03.csv", "PEMS03.npz", "PEMS03.txt", "PEMS03_dtw_distance.npy", "PEMS03_spatial_distance.npy"],
"PEMS04": ["PEMS04.csv", "PEMS04.npz", "PEMS04_dtw_distance.npy", "PEMS04_spatial_distance.npy"],
"PEMS07": ["PEMS07.csv", "PEMS07.npz", "PEMS07_dtw_distance.npy", "PEMS07_spatial_distance.npy"],
"PEMS08": ["PEMS08.csv", "PEMS08.npz", "PEMS08_dtw_distance.npy", "PEMS08_spatial_distance.npy"]
}
    current_dir = os.getcwd()  # current working directory
missing_adj = False
missing_main_files = False
    # Check whether the data folder exists
    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
        # print(f"Directory {data_dir} does not exist.")
        print("Downloading all required data files...")
        missing_adj = True
        missing_main_files = True
    else:
        # Check for get_adj.py in the data root
        if "get_adj.py" not in os.listdir(data_dir):
            # print("get_adj.py is missing from the data root.")
            missing_adj = True
        # Walk the expected file structure
        for subfolder, expected_files in expected_structure.items():
            subfolder_path = os.path.join(data_dir, subfolder)
            # Check whether the subfolder exists
            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
                # print(f"Subfolder {subfolder} does not exist.")
                missing_main_files = True
                continue
            # Actual files present in the subfolder
            actual_files = os.listdir(subfolder_path)
            # Check for missing files
            for expected_file in expected_files:
                if expected_file not in actual_files:
                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
                    if "_dtw_distance.npy" in expected_file or "_spatial_distance.npy" in expected_file:
                        missing_adj = True
                    else:
                        missing_main_files = True
    # Dispatch downloads based on which files are missing
    if missing_adj:
        download_adj_data(current_dir)
    if missing_main_files:
        download_kaggle_data(current_dir)
    return True
def download_adj_data(current_dir, max_retries=3):
    """
    Download and extract adj.zip, showing a download progress bar.
    Retries up to max_retries times if the download fails.
    """
    url = "http://code.zhang-heng.com/static/adj.zip"
    retries = 0
    while retries <= max_retries:
        try:
            print(f"Downloading adjacency-matrix files from {url} ...")
            response = requests.get(url, stream=True)
            if response.status_code == 200:
                total_size = int(response.headers.get('content-length', 0))
                block_size = 1024  # 1 KB
                t = tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading")
                zip_file_path = os.path.join(current_dir, "adj.zip")
                with open(zip_file_path, 'wb') as f:
                    for data in response.iter_content(block_size):
                        f.write(data)
                        t.update(len(data))
                t.close()
                # print("Download complete, file saved to:", zip_file_path)
                if os.path.exists(zip_file_path):
                    with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
                        zip_ref.extractall(current_dir)
                    # print("Dataset extracted to:", current_dir)
                    os.remove(zip_file_path)  # delete the zip file
                else:
                    print("Downloaded zip file not found; skipping extraction.")
                break  # download succeeded, leave the retry loop
            else:
                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
        except Exception as e:
            print(f"Error while downloading or extracting the dataset: {e}")
            print("If the link is invalid, verify the URL or retry later.")
        retries += 1
        if retries > max_retries:
            raise Exception(f"Download failed after the maximum number of retries ({max_retries}). Please check the link or your network connection.")
def download_kaggle_data(current_dir):
    """
    Download the KaggleHub dataset and merge its data folder into the
    current working directory, overwriting conflicting files if the
    target folder already exists.
    """
    try:
        print("Downloading the PEMS dataset...")
        path = kagglehub.dataset_download("elmahy/pems-dataset")
        # print("Path to KaggleHub dataset files:", path)
        if os.path.exists(path):
            data_folder_path = os.path.join(path, "data")
            if os.path.exists(data_folder_path):
                destination_path = os.path.join(current_dir, "data")
                # Merge the folders with shutil.copytree, overwriting conflicting files
                shutil.copytree(data_folder_path, destination_path, dirs_exist_ok=True)
                # print(f"The data folder was merged into: {destination_path}")
            # else:
            #     print("No data folder found; skipping the merge.")
        # else:
        #     print("KaggleHub dataset path not found; skipping.")
    except Exception as e:
        print(f"Error while downloading or processing the KaggleHub dataset: {e}")

# Entry point
if __name__ == "__main__":
    check_and_download_data()