更新STGODE模型:添加配置文件、优化模型代码、新增测试数据文件
This commit is contained in:
parent
56a7f309fa
commit
387f64efab
|
|
@ -5,9 +5,14 @@
|
||||||
</component>
|
</component>
|
||||||
<component name="ChangeListManager">
|
<component name="ChangeListManager">
|
||||||
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
|
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
|
||||||
|
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
|
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
|
||||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
|
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
|
||||||
|
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
|
||||||
</list>
|
</list>
|
||||||
<option name="SHOW_DIALOG" value="false" />
|
<option name="SHOW_DIALOG" value="false" />
|
||||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||||
|
|
@ -40,7 +45,7 @@
|
||||||
<entry key="$PROJECT_DIR$" value="main" />
|
<entry key="$PROJECT_DIR$" value="main" />
|
||||||
</map>
|
</map>
|
||||||
</option>
|
</option>
|
||||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$/STDEN" />
|
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||||
<option name="ROOT_SYNC" value="DONT_SYNC" />
|
<option name="ROOT_SYNC" value="DONT_SYNC" />
|
||||||
</component>
|
</component>
|
||||||
<component name="MarkdownSettingsMigration">
|
<component name="MarkdownSettingsMigration">
|
||||||
|
|
@ -57,10 +62,11 @@
|
||||||
<component name="PropertiesComponent"><![CDATA[{
|
<component name="PropertiesComponent"><![CDATA[{
|
||||||
"keyToString": {
|
"keyToString": {
|
||||||
"Python.STDEN.executor": "Debug",
|
"Python.STDEN.executor": "Debug",
|
||||||
|
"Python.STGODE.executor": "Run",
|
||||||
"Python.main.executor": "Run",
|
"Python.main.executor": "Run",
|
||||||
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||||
"git-widget-placeholder": "main",
|
"git-widget-placeholder": "STGODE",
|
||||||
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
|
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
|
||||||
"node.js.detected.package.eslint": "true",
|
"node.js.detected.package.eslint": "true",
|
||||||
"node.js.detected.package.tslint": "true",
|
"node.js.detected.package.tslint": "true",
|
||||||
|
|
@ -83,14 +89,14 @@
|
||||||
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
||||||
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
|
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
|
||||||
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
|
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
|
||||||
<window_info active="true" anchor="bottom" id="Debug" visible="true" weight="0.32867557" />
|
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
|
||||||
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
|
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
|
||||||
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
|
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
|
||||||
<window_info anchor="bottom" id="Run" weight="0.32867557" />
|
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
|
||||||
<window_info anchor="bottom" id="Version Control" order="0" />
|
<window_info anchor="bottom" id="Version Control" order="0" />
|
||||||
<window_info anchor="bottom" id="Problems" order="1" />
|
<window_info anchor="bottom" id="Problems" order="1" />
|
||||||
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
|
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
|
||||||
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32867557" />
|
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
|
||||||
<window_info anchor="bottom" id="Services" order="4" />
|
<window_info anchor="bottom" id="Services" order="4" />
|
||||||
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
||||||
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
|
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
|
||||||
|
|
@ -114,7 +120,7 @@
|
||||||
<recent name="$PROJECT_DIR$/models/STDEN" />
|
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||||
</key>
|
</key>
|
||||||
</component>
|
</component>
|
||||||
<component name="RunManager">
|
<component name="RunManager" selected="Python.STGODE">
|
||||||
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
|
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
|
||||||
<module name="Project-I" />
|
<module name="Project-I" />
|
||||||
<option name="ENV_FILES" value="" />
|
<option name="ENV_FILES" value="" />
|
||||||
|
|
@ -139,6 +145,34 @@
|
||||||
<option name="INPUT_FILE" value="" />
|
<option name="INPUT_FILE" value="" />
|
||||||
<method v="2" />
|
<method v="2" />
|
||||||
</configuration>
|
</configuration>
|
||||||
|
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
|
||||||
|
<module name="Project-I" />
|
||||||
|
<option name="ENV_FILES" value="" />
|
||||||
|
<option name="INTERPRETER_OPTIONS" value="" />
|
||||||
|
<option name="PARENT_ENVS" value="true" />
|
||||||
|
<envs>
|
||||||
|
<env name="PYTHONUNBUFFERED" value="1" />
|
||||||
|
</envs>
|
||||||
|
<option name="SDK_HOME" value="" />
|
||||||
|
<option name="SDK_NAME" value="TS" />
|
||||||
|
<option name="WORKING_DIRECTORY" value="" />
|
||||||
|
<option name="IS_MODULE_SDK" value="false" />
|
||||||
|
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||||
|
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||||
|
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||||
|
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||||
|
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
|
||||||
|
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||||
|
<option name="EMULATE_TERMINAL" value="false" />
|
||||||
|
<option name="MODULE_MODE" value="false" />
|
||||||
|
<option name="REDIRECT_INPUT" value="false" />
|
||||||
|
<option name="INPUT_FILE" value="" />
|
||||||
|
<method v="2" />
|
||||||
|
</configuration>
|
||||||
|
<list>
|
||||||
|
<item itemvalue="Python.STDEN" />
|
||||||
|
<item itemvalue="Python.STGODE" />
|
||||||
|
</list>
|
||||||
</component>
|
</component>
|
||||||
<component name="SharedIndexes">
|
<component name="SharedIndexes">
|
||||||
<attachedChunks>
|
<attachedChunks>
|
||||||
|
|
@ -157,6 +191,7 @@
|
||||||
<updated>1756727620810</updated>
|
<updated>1756727620810</updated>
|
||||||
<workItem from="1756727623101" duration="4721000" />
|
<workItem from="1756727623101" duration="4721000" />
|
||||||
<workItem from="1756856673845" duration="652000" />
|
<workItem from="1756856673845" duration="652000" />
|
||||||
|
<workItem from="1756864144998" duration="1063000" />
|
||||||
</task>
|
</task>
|
||||||
<servers />
|
<servers />
|
||||||
</component>
|
</component>
|
||||||
|
|
@ -176,5 +211,6 @@
|
||||||
</component>
|
</component>
|
||||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||||
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||||
|
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||||
</component>
|
</component>
|
||||||
</project>
|
</project>
|
||||||
|
|
@ -0,0 +1,60 @@
|
||||||
|
basic:
|
||||||
|
device: cuda:0
|
||||||
|
dataset: PEMS08
|
||||||
|
model: STGODE
|
||||||
|
mode: train
|
||||||
|
seed: 2025
|
||||||
|
|
||||||
|
data:
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
val_batch_size: 32
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
num_nodes: 170
|
||||||
|
batch_size: 64
|
||||||
|
input_dim: 1
|
||||||
|
lag: 12
|
||||||
|
horizon: 12
|
||||||
|
val_ratio: 0.2
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: False
|
||||||
|
normalizer: std
|
||||||
|
column_wise: False
|
||||||
|
default_graph: True
|
||||||
|
add_time_in_day: True
|
||||||
|
add_day_in_week: True
|
||||||
|
steps_per_day: 24
|
||||||
|
days_per_week: 7
|
||||||
|
|
||||||
|
model:
|
||||||
|
input_dim: 1
|
||||||
|
output_dim: 1
|
||||||
|
history: 12
|
||||||
|
horizon: 12
|
||||||
|
num_features: 1
|
||||||
|
rnn_units: 64
|
||||||
|
sigma1: 0.1
|
||||||
|
sigma2: 10
|
||||||
|
thres1: 0.6
|
||||||
|
thres2: 0.5
|
||||||
|
|
||||||
|
|
||||||
|
train:
|
||||||
|
loss: mae
|
||||||
|
batch_size: 64
|
||||||
|
epochs: 100
|
||||||
|
lr_init: 0.003
|
||||||
|
mape_thresh: 0.001
|
||||||
|
mae_thresh: None
|
||||||
|
debug: False
|
||||||
|
output_dim: 1
|
||||||
|
weight_decay: 0
|
||||||
|
lr_decay: False
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: "5,20,40,70"
|
||||||
|
early_stop: True
|
||||||
|
early_stop_patience: 15
|
||||||
|
grad_norm: False
|
||||||
|
max_grad_norm: 5
|
||||||
|
real_value: True
|
||||||
|
log_step: 3000
|
||||||
|
|
||||||
|
|
@ -117,7 +117,7 @@ class STGCNBlock(nn.Module):
|
||||||
class ODEGCN(nn.Module):
|
class ODEGCN(nn.Module):
|
||||||
""" the overall network framework """
|
""" the overall network framework """
|
||||||
|
|
||||||
def __init__(self, args):
|
def __init__(self, config):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
num_nodes : number of nodes in the graph
|
num_nodes : number of nodes in the graph
|
||||||
|
|
@ -129,11 +129,12 @@ class ODEGCN(nn.Module):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super(ODEGCN, self).__init__()
|
super(ODEGCN, self).__init__()
|
||||||
num_nodes = args['num_nodes']
|
args = config['model']
|
||||||
|
num_nodes = config['data']['num_nodes']
|
||||||
num_features = args['num_features']
|
num_features = args['num_features']
|
||||||
num_timesteps_input = args['history']
|
num_timesteps_input = args['history']
|
||||||
num_timesteps_output = args['horizon']
|
num_timesteps_output = args['horizon']
|
||||||
A_sp_hat, A_se_hat = get_A_hat(args)
|
A_sp_hat, A_se_hat = get_A_hat(config)
|
||||||
|
|
||||||
# spatial graph
|
# spatial graph
|
||||||
self.sp_blocks = nn.ModuleList(
|
self.sp_blocks = nn.ModuleList(
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ files = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_A_hat(args):
|
def get_A_hat(config):
|
||||||
"""read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
|
"""read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
|
|
@ -31,12 +31,13 @@ def get_A_hat(args):
|
||||||
dtw_matrix: array, semantic adjacency matrix
|
dtw_matrix: array, semantic adjacency matrix
|
||||||
sp_matrix: array, spatial adjacency matrix
|
sp_matrix: array, spatial adjacency matrix
|
||||||
"""
|
"""
|
||||||
filepath = './data/'
|
file_path = config['data']['graph_pkl_filename']
|
||||||
num_node = args['num_nodes']
|
filename = config['basic']['dataset']
|
||||||
file = files[num_node]
|
dataset_path = config['data']['dataset_dir']
|
||||||
filename = file[0][:6]
|
args = config['model']
|
||||||
|
|
||||||
data = np.load(filepath + file[0])['data']
|
data = np.load(file_path)
|
||||||
|
data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
num_node = data.shape[1]
|
num_node = data.shape[1]
|
||||||
mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
|
mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
|
||||||
std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
|
std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
|
||||||
|
|
@ -72,7 +73,7 @@ def get_A_hat(args):
|
||||||
with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
|
with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
|
||||||
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表
|
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表
|
||||||
# 使用 pandas 读取 CSV 文件,跳过标题行
|
# 使用 pandas 读取 CSV 文件,跳过标题行
|
||||||
df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
|
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
|
||||||
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
start = int(id_dict[row[0]])
|
start = int(id_dict[row[0]])
|
||||||
|
|
@ -82,7 +83,7 @@ def get_A_hat(args):
|
||||||
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
|
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
|
||||||
else:
|
else:
|
||||||
# 使用 pandas 读取 CSV 文件,跳过标题行
|
# 使用 pandas 读取 CSV 文件,跳过标题行
|
||||||
df = pd.read_csv(filepath + file[1], skiprows=1, header=None)
|
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
|
||||||
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
||||||
for _, row in df.iterrows():
|
for _, row in df.iterrows():
|
||||||
start = int(row[0])
|
start = int(row[0])
|
||||||
|
|
@ -98,7 +99,8 @@ def get_A_hat(args):
|
||||||
sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
|
sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
|
||||||
sp_matrix[sp_matrix < args['thres2']] = 0
|
sp_matrix[sp_matrix < args['thres2']] = 0
|
||||||
|
|
||||||
return get_normalized_adj(dtw_matrix).to(args['device']), get_normalized_adj(sp_matrix).to(args['device'])
|
return (get_normalized_adj(dtw_matrix).to(config['basic']['device']),
|
||||||
|
get_normalized_adj(sp_matrix).to(config['basic']['device']))
|
||||||
|
|
||||||
|
|
||||||
def get_normalized_adj(A):
|
def get_normalized_adj(A):
|
||||||
|
|
@ -115,16 +117,16 @@ def get_normalized_adj(A):
|
||||||
return torch.from_numpy(A_reg.astype(np.float32))
|
return torch.from_numpy(A_reg.astype(np.float32))
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
if __name__ == '__main__':
|
|
||||||
config = {
|
|
||||||
'sigma1': 0.1,
|
|
||||||
'sigma2': 10,
|
|
||||||
'thres1': 0.6,
|
|
||||||
'thres2': 0.5,
|
|
||||||
'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
|
||||||
}
|
|
||||||
|
|
||||||
for nodes in [358, 170, 883]:
|
if __name__ == '__main__':
|
||||||
args = {'num_nodes': nodes, **config}
|
config = {
|
||||||
get_A_hat(args)
|
'sigma1': 0.1,
|
||||||
|
'sigma2': 10,
|
||||||
|
'thres1': 0.6,
|
||||||
|
'thres2': 0.5,
|
||||||
|
'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||||
|
}
|
||||||
|
|
||||||
|
for nodes in [358, 170, 883]:
|
||||||
|
args = {'num_nodes': nodes, **config}
|
||||||
|
get_A_hat(args)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,12 @@
|
||||||
from models.STDEN.stden_model import STDENModel
|
from models.STDEN.stden_model import STDENModel
|
||||||
|
from models.STGODE.STGODE import ODEGCN
|
||||||
|
|
||||||
def model_selector(config):
|
def model_selector(config):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
model = None
|
model = None
|
||||||
match model_name:
|
match model_name:
|
||||||
case 'STDEN': model = STDENModel(config)
|
case 'STDEN':
|
||||||
|
model = STDENModel(config)
|
||||||
|
case 'STGODE':
|
||||||
|
model = ODEGCN(config)
|
||||||
return model
|
return model
|
||||||
Binary file not shown.
Binary file not shown.
Loading…
Reference in New Issue