Compare commits
No commits in common. "387f64efab6e951f87a997c2aa5842e27ae94de5" and "a8cc3a20fd44f24c099d27ab8f360be1237df093" have entirely different histories.
387f64efab
...
a8cc3a20fd
|
|
@ -1,216 +0,0 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AutoImportSettings">
|
||||
<option name="autoReloadType" value="SELECTIVE" />
|
||||
</component>
|
||||
<component name="ChangeListManager">
|
||||
<list default="true" id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="">
|
||||
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/lib/utils.py" beforeDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_eval.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_eval.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/STDEN/stden_train.py" beforeDir="false" afterPath="$PROJECT_DIR$/STDEN/stden_train.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/models/STGODE/STGODE.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/STGODE.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/models/STGODE/adj.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/STGODE/adj.py" afterDir="false" />
|
||||
<change beforePath="$PROJECT_DIR$/models/model_selector.py" beforeDir="false" afterPath="$PROJECT_DIR$/models/model_selector.py" afterDir="false" />
|
||||
</list>
|
||||
<option name="SHOW_DIALOG" value="false" />
|
||||
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
||||
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
||||
<option name="LAST_RESOLUTION" value="IGNORE" />
|
||||
</component>
|
||||
<component name="FileTemplateManagerImpl">
|
||||
<option name="RECENT_TEMPLATES">
|
||||
<list>
|
||||
<option value="Python Script" />
|
||||
</list>
|
||||
</option>
|
||||
</component>
|
||||
<component name="Git.Settings">
|
||||
<excluded-from-favorite>
|
||||
<branch-storage>
|
||||
<map>
|
||||
<entry type="LOCAL">
|
||||
<value>
|
||||
<list>
|
||||
<branch-info repo="$PROJECT_DIR$" source="main" />
|
||||
</list>
|
||||
</value>
|
||||
</entry>
|
||||
</map>
|
||||
</branch-storage>
|
||||
</excluded-from-favorite>
|
||||
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
||||
<map>
|
||||
<entry key="$PROJECT_DIR$" value="main" />
|
||||
</map>
|
||||
</option>
|
||||
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
||||
<option name="ROOT_SYNC" value="DONT_SYNC" />
|
||||
</component>
|
||||
<component name="MarkdownSettingsMigration">
|
||||
<option name="stateVersion" value="1" />
|
||||
</component>
|
||||
<component name="ProjectColorInfo">{
|
||||
"associatedIndex": 3
|
||||
}</component>
|
||||
<component name="ProjectId" id="3264JlB7seHXuXCCcdmTyEsXI45" />
|
||||
<component name="ProjectViewState">
|
||||
<option name="hideEmptyMiddlePackages" value="true" />
|
||||
<option name="showLibraryContents" value="true" />
|
||||
</component>
|
||||
<component name="PropertiesComponent"><![CDATA[{
|
||||
"keyToString": {
|
||||
"Python.STDEN.executor": "Debug",
|
||||
"Python.STGODE.executor": "Run",
|
||||
"Python.main.executor": "Run",
|
||||
"RunOnceActivity.OpenProjectViewOnStart": "true",
|
||||
"RunOnceActivity.ShowReadmeOnStart": "true",
|
||||
"git-widget-placeholder": "STGODE",
|
||||
"last_opened_file_path": "/home/czzhangheng/code/Project-I/main.py",
|
||||
"node.js.detected.package.eslint": "true",
|
||||
"node.js.detected.package.tslint": "true",
|
||||
"node.js.selected.package.eslint": "(autodetect)",
|
||||
"node.js.selected.package.tslint": "(autodetect)",
|
||||
"nodejs_package_manager_path": "npm",
|
||||
"vue.rearranger.settings.migration": "true"
|
||||
}
|
||||
}]]></component>
|
||||
<component name="RdControllerToolWindowsLayoutState" isNewUi="true">
|
||||
<layout>
|
||||
<window_info id="Space Code Reviews" show_stripe_button="false" />
|
||||
<window_info id="Bookmarks" show_stripe_button="false" side_tool="true" />
|
||||
<window_info id="Merge Requests" show_stripe_button="false" />
|
||||
<window_info id="Commit_Guest" show_stripe_button="false" />
|
||||
<window_info id="Pull Requests" show_stripe_button="false" />
|
||||
<window_info id="Learn" show_stripe_button="false" />
|
||||
<window_info active="true" content_ui="combo" id="Project" order="0" visible="true" weight="0.27326387" />
|
||||
<window_info id="Commit" order="1" weight="0.25" />
|
||||
<window_info id="Structure" order="2" side_tool="true" weight="0.25" />
|
||||
<window_info anchor="bottom" id="Database Changes" show_stripe_button="false" />
|
||||
<window_info anchor="bottom" id="TypeScript" show_stripe_button="false" />
|
||||
<window_info anchor="bottom" id="Debug" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="TODO" show_stripe_button="false" />
|
||||
<window_info anchor="bottom" id="File Transfer" show_stripe_button="false" />
|
||||
<window_info active="true" anchor="bottom" id="Run" visible="true" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="Version Control" order="0" />
|
||||
<window_info anchor="bottom" id="Problems" order="1" />
|
||||
<window_info anchor="bottom" id="Problems View" order="2" weight="0.33686176" />
|
||||
<window_info anchor="bottom" id="Terminal" order="3" weight="0.32989067" />
|
||||
<window_info anchor="bottom" id="Services" order="4" />
|
||||
<window_info anchor="bottom" id="Python Packages" order="5" weight="0.1" />
|
||||
<window_info anchor="bottom" id="Python Console" order="6" weight="0.1" />
|
||||
<window_info anchor="right" id="Endpoints" show_stripe_button="false" />
|
||||
<window_info anchor="right" id="SciView" show_stripe_button="false" />
|
||||
<window_info anchor="right" content_ui="combo" id="Notifications" order="0" weight="0.25" />
|
||||
<window_info anchor="right" id="AIAssistant" order="1" weight="0.25" />
|
||||
<window_info anchor="right" id="Database" order="2" weight="0.25" />
|
||||
<window_info anchor="right" id="Gradle" order="3" weight="0.25" />
|
||||
<window_info anchor="right" id="Maven" order="4" weight="0.25" />
|
||||
<window_info anchor="right" id="Plots" order="5" weight="0.1" />
|
||||
</layout>
|
||||
</component>
|
||||
<component name="RecentsManager">
|
||||
<key name="CopyFile.RECENT_KEYS">
|
||||
<recent name="$PROJECT_DIR$/trainer" />
|
||||
<recent name="$PROJECT_DIR$/configs/STDEN" />
|
||||
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||
</key>
|
||||
<key name="MoveFile.RECENT_KEYS">
|
||||
<recent name="$PROJECT_DIR$/models/STDEN" />
|
||||
</key>
|
||||
</component>
|
||||
<component name="RunManager" selected="Python.STGODE">
|
||||
<configuration name="STDEN" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="Project-I" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="SDK_NAME" value="TS" />
|
||||
<option name="WORKING_DIRECTORY" value="" />
|
||||
<option name="IS_MODULE_SDK" value="false" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="--config ./configs/STDEN/PEMS08.yaml" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<configuration name="STGODE" type="PythonConfigurationType" factoryName="Python">
|
||||
<module name="Project-I" />
|
||||
<option name="ENV_FILES" value="" />
|
||||
<option name="INTERPRETER_OPTIONS" value="" />
|
||||
<option name="PARENT_ENVS" value="true" />
|
||||
<envs>
|
||||
<env name="PYTHONUNBUFFERED" value="1" />
|
||||
</envs>
|
||||
<option name="SDK_HOME" value="" />
|
||||
<option name="SDK_NAME" value="TS" />
|
||||
<option name="WORKING_DIRECTORY" value="" />
|
||||
<option name="IS_MODULE_SDK" value="false" />
|
||||
<option name="ADD_CONTENT_ROOTS" value="true" />
|
||||
<option name="ADD_SOURCE_ROOTS" value="true" />
|
||||
<EXTENSION ID="PythonCoverageRunConfigurationExtension" runner="coverage.py" />
|
||||
<option name="SCRIPT_NAME" value="$PROJECT_DIR$/main.py" />
|
||||
<option name="PARAMETERS" value="--config ./configs/STGODE/PEMS08.yaml" />
|
||||
<option name="SHOW_COMMAND_LINE" value="false" />
|
||||
<option name="EMULATE_TERMINAL" value="false" />
|
||||
<option name="MODULE_MODE" value="false" />
|
||||
<option name="REDIRECT_INPUT" value="false" />
|
||||
<option name="INPUT_FILE" value="" />
|
||||
<method v="2" />
|
||||
</configuration>
|
||||
<list>
|
||||
<item itemvalue="Python.STDEN" />
|
||||
<item itemvalue="Python.STGODE" />
|
||||
</list>
|
||||
</component>
|
||||
<component name="SharedIndexes">
|
||||
<attachedChunks>
|
||||
<set>
|
||||
<option value="bundled-python-sdk-eebebe6c2be4-b11f5e8da5ad-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-233.15325.20" />
|
||||
</set>
|
||||
</attachedChunks>
|
||||
</component>
|
||||
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="应用程序级" UseSingleDictionary="true" transferred="true" />
|
||||
<component name="TaskManager">
|
||||
<task active="true" id="Default" summary="默认任务">
|
||||
<changelist id="8b1aea27-342c-41a7-b776-2aba4fceda0d" name="更改" comment="" />
|
||||
<created>1756727620810</created>
|
||||
<option name="number" value="Default" />
|
||||
<option name="presentableId" value="Default" />
|
||||
<updated>1756727620810</updated>
|
||||
<workItem from="1756727623101" duration="4721000" />
|
||||
<workItem from="1756856673845" duration="652000" />
|
||||
<workItem from="1756864144998" duration="1063000" />
|
||||
</task>
|
||||
<servers />
|
||||
</component>
|
||||
<component name="TypeScriptGeneratedFilesManager">
|
||||
<option name="version" value="3" />
|
||||
</component>
|
||||
<component name="XDebuggerManager">
|
||||
<breakpoint-manager>
|
||||
<breakpoints>
|
||||
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
||||
<url>file://$PROJECT_DIR$/models/STDEN/stden_model.py</url>
|
||||
<line>131</line>
|
||||
<option name="timeStamp" value="5" />
|
||||
</line-breakpoint>
|
||||
</breakpoints>
|
||||
</breakpoint-manager>
|
||||
</component>
|
||||
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
||||
<SUITE FILE_PATH="coverage/Project_I$main.coverage" NAME="STDEN 覆盖结果" MODIFIED="1756832980407" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
<SUITE FILE_PATH="coverage/Project_I$STGODE.coverage" NAME="STGODE 覆盖结果" MODIFIED="1756864828915" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="true" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="" />
|
||||
</component>
|
||||
</project>
|
||||
1
STDEN
1
STDEN
|
|
@ -1 +0,0 @@
|
|||
Subproject commit e50a1ba6d70528b3e684c85f316aed05bb5085f2
|
||||
|
|
@ -1,60 +0,0 @@
|
|||
basic:
|
||||
device: cuda:0
|
||||
dataset: PEMS08
|
||||
model: STGODE
|
||||
mode: train
|
||||
seed: 2025
|
||||
|
||||
data:
|
||||
dataset_dir: data/PEMS08
|
||||
val_batch_size: 32
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
num_nodes: 170
|
||||
batch_size: 64
|
||||
input_dim: 1
|
||||
lag: 12
|
||||
horizon: 12
|
||||
val_ratio: 0.2
|
||||
test_ratio: 0.2
|
||||
tod: False
|
||||
normalizer: std
|
||||
column_wise: False
|
||||
default_graph: True
|
||||
add_time_in_day: True
|
||||
add_day_in_week: True
|
||||
steps_per_day: 24
|
||||
days_per_week: 7
|
||||
|
||||
model:
|
||||
input_dim: 1
|
||||
output_dim: 1
|
||||
history: 12
|
||||
horizon: 12
|
||||
num_features: 1
|
||||
rnn_units: 64
|
||||
sigma1: 0.1
|
||||
sigma2: 10
|
||||
thres1: 0.6
|
||||
thres2: 0.5
|
||||
|
||||
|
||||
train:
|
||||
loss: mae
|
||||
batch_size: 64
|
||||
epochs: 100
|
||||
lr_init: 0.003
|
||||
mape_thresh: 0.001
|
||||
mae_thresh: None
|
||||
debug: False
|
||||
output_dim: 1
|
||||
weight_decay: 0
|
||||
lr_decay: False
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: "5,20,40,70"
|
||||
early_stop: True
|
||||
early_stop_patience: 15
|
||||
grad_norm: False
|
||||
max_grad_norm: 5
|
||||
real_value: True
|
||||
log_step: 3000
|
||||
|
||||
|
|
@ -1,180 +0,0 @@
|
|||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from models.STGODE.odegcn import ODEG
|
||||
from models.STGODE.adj import get_A_hat
|
||||
|
||||
class Chomp1d(nn.Module):
|
||||
"""
|
||||
extra dimension will be added by padding, remove it
|
||||
"""
|
||||
|
||||
def __init__(self, chomp_size):
|
||||
super(Chomp1d, self).__init__()
|
||||
self.chomp_size = chomp_size
|
||||
|
||||
def forward(self, x):
|
||||
return x[:, :, :, :-self.chomp_size].contiguous()
|
||||
|
||||
|
||||
class TemporalConvNet(nn.Module):
|
||||
"""
|
||||
time dilation convolution
|
||||
"""
|
||||
|
||||
def __init__(self, num_inputs, num_channels, kernel_size=2, dropout=0.2):
|
||||
"""
|
||||
Args:
|
||||
num_inputs : channel's number of input data's feature
|
||||
num_channels : numbers of data feature tranform channels, the last is the output channel
|
||||
kernel_size : using 1d convolution, so the real kernel is (1, kernel_size)
|
||||
"""
|
||||
super(TemporalConvNet, self).__init__()
|
||||
layers = []
|
||||
num_levels = len(num_channels)
|
||||
for i in range(num_levels):
|
||||
dilation_size = 2 ** i
|
||||
in_channels = num_inputs if i == 0 else num_channels[i - 1]
|
||||
out_channels = num_channels[i]
|
||||
padding = (kernel_size - 1) * dilation_size
|
||||
self.conv = nn.Conv2d(in_channels, out_channels, (1, kernel_size), dilation=(1, dilation_size),
|
||||
padding=(0, padding))
|
||||
self.conv.weight.data.normal_(0, 0.01)
|
||||
self.chomp = Chomp1d(padding)
|
||||
self.relu = nn.ReLU()
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
|
||||
layers += [nn.Sequential(self.conv, self.chomp, self.relu, self.dropout)]
|
||||
|
||||
self.network = nn.Sequential(*layers)
|
||||
self.downsample = nn.Conv2d(num_inputs, num_channels[-1], (1, 1)) if num_inputs != num_channels[-1] else None
|
||||
if self.downsample:
|
||||
self.downsample.weight.data.normal_(0, 0.01)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
like ResNet
|
||||
Args:
|
||||
X : input data of shape (B, N, T, F)
|
||||
"""
|
||||
# permute shape to (B, F, N, T)
|
||||
y = x.permute(0, 3, 1, 2)
|
||||
y = F.relu(self.network(y) + self.downsample(y) if self.downsample else y)
|
||||
y = y.permute(0, 2, 3, 1)
|
||||
return y
|
||||
|
||||
|
||||
class GCN(nn.Module):
|
||||
def __init__(self, A_hat, in_channels, out_channels, ):
|
||||
super(GCN, self).__init__()
|
||||
self.A_hat = A_hat
|
||||
self.theta = nn.Parameter(torch.FloatTensor(in_channels, out_channels))
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
stdv = 1. / math.sqrt(self.theta.shape[1])
|
||||
self.theta.data.uniform_(-stdv, stdv)
|
||||
|
||||
def forward(self, X):
|
||||
y = torch.einsum('ij, kjlm-> kilm', self.A_hat, X)
|
||||
return F.relu(torch.einsum('kjlm, mn->kjln', y, self.theta))
|
||||
|
||||
|
||||
class STGCNBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, num_nodes, A_hat):
|
||||
"""
|
||||
Args:
|
||||
in_channels: Number of input features at each node in each time step.
|
||||
out_channels: a list of feature channels in timeblock, the last is output feature channel
|
||||
num_nodes: Number of nodes in the graph
|
||||
A_hat: the normalized adjacency matrix
|
||||
"""
|
||||
super(STGCNBlock, self).__init__()
|
||||
self.A_hat = A_hat
|
||||
self.temporal1 = TemporalConvNet(num_inputs=in_channels,
|
||||
num_channels=out_channels)
|
||||
self.odeg = ODEG(out_channels[-1], 12, A_hat, time=6)
|
||||
self.temporal2 = TemporalConvNet(num_inputs=out_channels[-1],
|
||||
num_channels=out_channels)
|
||||
self.batch_norm = nn.BatchNorm2d(num_nodes)
|
||||
|
||||
def forward(self, X):
|
||||
"""
|
||||
Args:
|
||||
X: Input data of shape (batch_size, num_nodes, num_timesteps, num_features)
|
||||
Return:
|
||||
Output data of shape(batch_size, num_nodes, num_timesteps, out_channels[-1])
|
||||
"""
|
||||
t = self.temporal1(X)
|
||||
t = self.odeg(t)
|
||||
t = self.temporal2(F.relu(t))
|
||||
|
||||
return self.batch_norm(t)
|
||||
|
||||
|
||||
class ODEGCN(nn.Module):
|
||||
""" the overall network framework """
|
||||
|
||||
def __init__(self, config):
|
||||
"""
|
||||
Args:
|
||||
num_nodes : number of nodes in the graph
|
||||
num_features : number of features at each node in each time step
|
||||
num_timesteps_input : number of past time steps fed into the network
|
||||
num_timesteps_output : desired number of future time steps output by the network
|
||||
A_sp_hat : nomarlized adjacency spatial matrix
|
||||
A_se_hat : nomarlized adjacency semantic matrix
|
||||
"""
|
||||
|
||||
super(ODEGCN, self).__init__()
|
||||
args = config['model']
|
||||
num_nodes = config['data']['num_nodes']
|
||||
num_features = args['num_features']
|
||||
num_timesteps_input = args['history']
|
||||
num_timesteps_output = args['horizon']
|
||||
A_sp_hat, A_se_hat = get_A_hat(config)
|
||||
|
||||
# spatial graph
|
||||
self.sp_blocks = nn.ModuleList(
|
||||
[nn.Sequential(
|
||||
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
|
||||
num_nodes=num_nodes, A_hat=A_sp_hat),
|
||||
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
|
||||
num_nodes=num_nodes, A_hat=A_sp_hat)) for _ in range(3)
|
||||
])
|
||||
# semantic graph
|
||||
self.se_blocks = nn.ModuleList([nn.Sequential(
|
||||
STGCNBlock(in_channels=num_features, out_channels=[64, 32, 64],
|
||||
num_nodes=num_nodes, A_hat=A_se_hat),
|
||||
STGCNBlock(in_channels=64, out_channels=[64, 32, 64],
|
||||
num_nodes=num_nodes, A_hat=A_se_hat)) for _ in range(3)
|
||||
])
|
||||
|
||||
self.pred = nn.Sequential(
|
||||
nn.Linear(num_timesteps_input * 64, num_timesteps_output * 32),
|
||||
nn.ReLU(),
|
||||
nn.Linear(num_timesteps_output * 32, num_timesteps_output)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Args:
|
||||
x : input data of shape (batch_size, num_nodes, num_timesteps, num_features) == (B, N, T, F)
|
||||
Returns:
|
||||
prediction for future of shape (batch_size, num_nodes, num_timesteps_output)
|
||||
"""
|
||||
x = x[..., 0:1].permute(0, 2, 1, 3)
|
||||
outs = []
|
||||
# spatial graph
|
||||
for blk in self.sp_blocks:
|
||||
outs.append(blk(x))
|
||||
# semantic graph
|
||||
for blk in self.se_blocks:
|
||||
outs.append(blk(x))
|
||||
outs = torch.stack(outs)
|
||||
x = torch.max(outs, dim=0)[0]
|
||||
x = x.reshape((x.shape[0], x.shape[1], -1))
|
||||
|
||||
return self.pred(x).permute(0,2,1).unsqueeze(dim=-1)
|
||||
|
|
@ -1,132 +0,0 @@
|
|||
import csv
|
||||
import os
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from fastdtw import fastdtw
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
|
||||
files = {
|
||||
358: ['PEMS03/PEMS03.npz', 'PEMS03/PEMS03.csv'],
|
||||
307: ['PEMS04/PEMS04.npz', 'PEMS04/PEMS04.csv'],
|
||||
883: ['PEMS07/PEMS07.npz', 'PEMS07/PEMS07.csv'],
|
||||
170: ['PEMS08/PEMS08.npz', 'PEMS08/PEMS08.csv'],
|
||||
# 'pemsbay': ['PEMSBAY/pems_bay.npz', 'PEMSBAY/distance.csv'],
|
||||
# 'pemsD7M': ['PEMSD7M/PEMSD7M.npz', 'PEMSD7M/distance.csv'],
|
||||
# 'pemsD7L': ['PEMSD7L/PEMSD7L.npz', 'PEMSD7L/distance.csv']
|
||||
}
|
||||
|
||||
|
||||
def get_A_hat(config):
|
||||
"""read data, generate spatial adjacency matrix and semantic adjacency matrix by dtw
|
||||
|
||||
Args:
|
||||
sigma1: float, default=0.1, sigma for the semantic matrix
|
||||
sigma2: float, default=10, sigma for the spatial matrix
|
||||
thres1: float, default=0.6, the threshold for the semantic matrix
|
||||
thres2: float, default=0.5, the threshold for the spatial matrix
|
||||
|
||||
Returns:
|
||||
data: tensor, T * N * 1
|
||||
dtw_matrix: array, semantic adjacency matrix
|
||||
sp_matrix: array, spatial adjacency matrix
|
||||
"""
|
||||
file_path = config['data']['graph_pkl_filename']
|
||||
filename = config['basic']['dataset']
|
||||
dataset_path = config['data']['dataset_dir']
|
||||
args = config['model']
|
||||
|
||||
data = np.load(file_path)
|
||||
data = np.nan_to_num(data, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
num_node = data.shape[1]
|
||||
mean_value = np.mean(data, axis=(0, 1)).reshape(1, 1, -1)
|
||||
std_value = np.std(data, axis=(0, 1)).reshape(1, 1, -1)
|
||||
data = (data - mean_value) / std_value
|
||||
|
||||
# 计算dtw_distance, 如果存在缓存则直接读取缓存
|
||||
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy'):
|
||||
data_mean = np.mean([data[:, :, 0][24 * 12 * i: 24 * 12 * (i + 1)] for i in range(data.shape[0] // (24 * 12))],
|
||||
axis=0)
|
||||
data_mean = data_mean.squeeze().T
|
||||
dtw_distance = np.zeros((num_node, num_node))
|
||||
for i in tqdm(range(num_node)):
|
||||
for j in range(i, num_node):
|
||||
dtw_distance[i][j] = fastdtw(data_mean[i], data_mean[j], radius=6)[0]
|
||||
for i in range(num_node):
|
||||
for j in range(i):
|
||||
dtw_distance[i][j] = dtw_distance[j][i]
|
||||
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy', dtw_distance)
|
||||
|
||||
dist_matrix = np.load(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_dtw_distance.npy')
|
||||
|
||||
mean = np.mean(dist_matrix)
|
||||
std = np.std(dist_matrix)
|
||||
dist_matrix = (dist_matrix - mean) / std
|
||||
sigma = args['sigma1']
|
||||
dist_matrix = np.exp(-dist_matrix ** 2 / sigma ** 2)
|
||||
dtw_matrix = np.zeros_like(dist_matrix)
|
||||
dtw_matrix[dist_matrix > args['thres1']] = 1
|
||||
|
||||
# 计算spatial_distance, 如果存在缓存则直接读取缓存
|
||||
if not os.path.exists(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy'):
|
||||
if num_node == 358:
|
||||
with open(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}.txt', 'r') as f:
|
||||
id_dict = {int(i): idx for idx, i in enumerate(f.read().strip().split('\n'))} # 建立映射列表
|
||||
# 使用 pandas 读取 CSV 文件,跳过标题行
|
||||
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
|
||||
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
||||
for _, row in df.iterrows():
|
||||
start = int(id_dict[row[0]])
|
||||
end = int(id_dict[row[1]])
|
||||
dist_matrix[start][end] = float(row[2])
|
||||
dist_matrix[end][start] = float(row[2])
|
||||
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
|
||||
else:
|
||||
# 使用 pandas 读取 CSV 文件,跳过标题行
|
||||
df = pd.read_csv(f'{dataset_path}/{filename}.csv', skiprows=1, header=None)
|
||||
dist_matrix = np.zeros((num_node, num_node)) + float('inf')
|
||||
for _, row in df.iterrows():
|
||||
start = int(row[0])
|
||||
end = int(row[1])
|
||||
dist_matrix[start][end] = float(row[2])
|
||||
dist_matrix[end][start] = float(row[2])
|
||||
np.save(f'data/PEMS0{filename[-1]}/PEMS0{filename[-1]}_spatial_distance.npy', dist_matrix)
|
||||
# normalization
|
||||
std = np.std(dist_matrix[dist_matrix != float('inf')])
|
||||
mean = np.mean(dist_matrix[dist_matrix != float('inf')])
|
||||
dist_matrix = (dist_matrix - mean) / std
|
||||
sigma = args['sigma2']
|
||||
sp_matrix = np.exp(- dist_matrix ** 2 / sigma ** 2)
|
||||
sp_matrix[sp_matrix < args['thres2']] = 0
|
||||
|
||||
return (get_normalized_adj(dtw_matrix).to(config['basic']['device']),
|
||||
get_normalized_adj(sp_matrix).to(config['basic']['device']))
|
||||
|
||||
|
||||
def get_normalized_adj(A):
|
||||
"""
|
||||
Returns a tensor, the degree normalized adjacency matrix.
|
||||
"""
|
||||
alpha = 0.8
|
||||
D = np.array(np.sum(A, axis=1)).reshape((-1,))
|
||||
D[D <= 10e-5] = 10e-5 # Prevent infs
|
||||
diag = np.reciprocal(np.sqrt(D))
|
||||
A_wave = np.multiply(np.multiply(diag.reshape((-1, 1)), A),
|
||||
diag.reshape((1, -1)))
|
||||
A_reg = alpha / 2 * (np.eye(A.shape[0]) + A_wave)
|
||||
return torch.from_numpy(A_reg.astype(np.float32))
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
config = {
|
||||
'sigma1': 0.1,
|
||||
'sigma2': 10,
|
||||
'thres1': 0.6,
|
||||
'thres2': 0.5,
|
||||
'device': 'cuda:0' if torch.cuda.is_available() else 'cpu'
|
||||
}
|
||||
|
||||
for nodes in [358, 170, 883]:
|
||||
args = {'num_nodes': nodes, **config}
|
||||
get_A_hat(args)
|
||||
|
|
@ -1,74 +0,0 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
# Whether use adjoint method or not.
|
||||
adjoint = False
|
||||
if adjoint:
|
||||
from torchdiffeq import odeint_adjoint as odeint
|
||||
else:
|
||||
from torchdiffeq import odeint
|
||||
|
||||
|
||||
# Define the ODE function.
|
||||
# Input:
|
||||
# --- t: A tensor with shape [], meaning the current time.
|
||||
# --- x: A tensor with shape [#batches, dims], meaning the value of x at t.
|
||||
# Output:
|
||||
# --- dx/dt: A tensor with shape [#batches, dims], meaning the derivative of x at t.
|
||||
class ODEFunc(nn.Module):
|
||||
|
||||
def __init__(self, feature_dim, temporal_dim, adj):
|
||||
super(ODEFunc, self).__init__()
|
||||
self.adj = adj
|
||||
self.x0 = None
|
||||
self.alpha = nn.Parameter(0.8 * torch.ones(adj.shape[1]))
|
||||
self.beta = 0.6
|
||||
self.w = nn.Parameter(torch.eye(feature_dim))
|
||||
self.d = nn.Parameter(torch.zeros(feature_dim) + 1)
|
||||
self.w2 = nn.Parameter(torch.eye(temporal_dim))
|
||||
self.d2 = nn.Parameter(torch.zeros(temporal_dim) + 1)
|
||||
|
||||
def forward(self, t, x):
|
||||
alpha = torch.sigmoid(self.alpha).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)
|
||||
xa = torch.einsum('ij, kjlm->kilm', self.adj, x)
|
||||
|
||||
# ensure the eigenvalues to be less than 1
|
||||
d = torch.clamp(self.d, min=0, max=1)
|
||||
w = torch.mm(self.w * d, torch.t(self.w))
|
||||
xw = torch.einsum('ijkl, lm->ijkm', x, w)
|
||||
|
||||
d2 = torch.clamp(self.d2, min=0, max=1)
|
||||
w2 = torch.mm(self.w2 * d2, torch.t(self.w2))
|
||||
xw2 = torch.einsum('ijkl, km->ijml', x, w2)
|
||||
|
||||
f = alpha / 2 * xa - x + xw - x + xw2 - x + self.x0
|
||||
return f
|
||||
|
||||
|
||||
class ODEblock(nn.Module):
|
||||
def __init__(self, odefunc, t=torch.tensor([0,1])):
|
||||
super(ODEblock, self).__init__()
|
||||
self.t = t
|
||||
self.odefunc = odefunc
|
||||
|
||||
def set_x0(self, x0):
|
||||
self.odefunc.x0 = x0.clone().detach()
|
||||
|
||||
def forward(self, x):
|
||||
t = self.t.type_as(x)
|
||||
z = odeint(self.odefunc, x, t, method='euler')[1]
|
||||
return z
|
||||
|
||||
|
||||
# Define the ODEGCN model.
|
||||
class ODEG(nn.Module):
|
||||
def __init__(self, feature_dim, temporal_dim, adj, time):
|
||||
super(ODEG, self).__init__()
|
||||
self.odeblock = ODEblock(ODEFunc(feature_dim, temporal_dim, adj), t=torch.tensor([0, time]))
|
||||
|
||||
def forward(self, x):
|
||||
self.odeblock.set_x0(x)
|
||||
z = self.odeblock(x)
|
||||
return F.relu(z)
|
||||
|
|
@ -1,12 +1,8 @@
|
|||
from models.STDEN.stden_model import STDENModel
|
||||
from models.STGODE.STGODE import ODEGCN
|
||||
|
||||
def model_selector(config):
|
||||
model_name = config['basic']['model']
|
||||
model = None
|
||||
match model_name:
|
||||
case 'STDEN':
|
||||
model = STDENModel(config)
|
||||
case 'STGODE':
|
||||
model = ODEGCN(config)
|
||||
case 'STDEN': model = STDENModel(config)
|
||||
return model
|
||||
Binary file not shown.
BIN
test_spatial.npy
BIN
test_spatial.npy
Binary file not shown.
Loading…
Reference in New Issue