Compare commits
2 Commits
19a02ba7ae
...
66a23ffbbb
| Author | SHA1 | Date |
|---|---|---|
|
|
66a23ffbbb | |
|
|
df8c573f4c |
|
|
@ -160,3 +160,4 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
STDEN/
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<Languages>
|
||||||
|
<language minSize="136" name="Python" />
|
||||||
|
</Languages>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="8">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,296 @@
|
||||||
|
from,to,cost
|
||||||
|
9,153,310.6
|
||||||
|
153,62,330.9
|
||||||
|
62,111,332.9
|
||||||
|
111,11,324.2
|
||||||
|
11,28,336.0
|
||||||
|
28,169,133.7
|
||||||
|
138,135,354.7
|
||||||
|
135,133,387.9
|
||||||
|
133,163,337.1
|
||||||
|
163,20,352.0
|
||||||
|
20,19,420.8
|
||||||
|
19,14,351.3
|
||||||
|
14,39,340.2
|
||||||
|
39,164,350.3
|
||||||
|
164,167,365.2
|
||||||
|
167,70,359.0
|
||||||
|
70,59,388.2
|
||||||
|
59,58,305.7
|
||||||
|
58,67,294.4
|
||||||
|
67,66,299.5
|
||||||
|
66,55,313.3
|
||||||
|
55,53,332.1
|
||||||
|
53,150,278.9
|
||||||
|
150,61,308.4
|
||||||
|
61,64,311.4
|
||||||
|
64,63,243.6
|
||||||
|
47,65,372.8
|
||||||
|
65,48,319.4
|
||||||
|
48,49,309.7
|
||||||
|
49,54,320.5
|
||||||
|
54,56,318.3
|
||||||
|
56,57,297.9
|
||||||
|
57,68,293.5
|
||||||
|
68,69,342.5
|
||||||
|
69,60,318.0
|
||||||
|
60,17,305.9
|
||||||
|
17,5,321.4
|
||||||
|
5,18,402.2
|
||||||
|
18,22,447.4
|
||||||
|
22,30,377.5
|
||||||
|
30,29,417.7
|
||||||
|
29,21,360.8
|
||||||
|
21,132,407.6
|
||||||
|
132,134,386.9
|
||||||
|
134,136,350.2
|
||||||
|
123,121,326.3
|
||||||
|
121,140,385.2
|
||||||
|
140,118,393.0
|
||||||
|
118,96,296.7
|
||||||
|
96,94,398.2
|
||||||
|
94,86,337.1
|
||||||
|
86,78,473.8
|
||||||
|
78,46,353.4
|
||||||
|
46,152,385.7
|
||||||
|
152,157,350.0
|
||||||
|
157,35,354.4
|
||||||
|
35,77,356.1
|
||||||
|
77,52,354.2
|
||||||
|
52,3,357.8
|
||||||
|
3,16,382.4
|
||||||
|
16,0,55.7
|
||||||
|
42,12,335.1
|
||||||
|
12,139,328.8
|
||||||
|
139,168,412.6
|
||||||
|
168,154,337.3
|
||||||
|
154,143,370.7
|
||||||
|
143,10,6.3
|
||||||
|
107,105,354.6
|
||||||
|
105,104,386.9
|
||||||
|
104,148,362.1
|
||||||
|
148,97,316.3
|
||||||
|
97,101,380.7
|
||||||
|
101,137,361.4
|
||||||
|
137,102,365.5
|
||||||
|
102,24,375.5
|
||||||
|
24,166,312.2
|
||||||
|
129,156,256.1
|
||||||
|
156,33,329.1
|
||||||
|
33,32,356.5
|
||||||
|
91,89,405.6
|
||||||
|
89,147,347.0
|
||||||
|
147,15,351.7
|
||||||
|
15,44,339.5
|
||||||
|
44,41,350.8
|
||||||
|
41,43,322.6
|
||||||
|
43,100,338.9
|
||||||
|
100,83,347.9
|
||||||
|
83,87,327.2
|
||||||
|
87,88,321.0
|
||||||
|
88,75,335.8
|
||||||
|
75,51,384.8
|
||||||
|
51,73,391.1
|
||||||
|
73,71,289.3
|
||||||
|
31,155,260.0
|
||||||
|
155,34,320.4
|
||||||
|
34,128,393.3
|
||||||
|
145,115,399.4
|
||||||
|
115,112,328.1
|
||||||
|
112,8,469.4
|
||||||
|
8,117,816.2
|
||||||
|
117,125,397.1
|
||||||
|
125,127,372.7
|
||||||
|
127,109,380.5
|
||||||
|
109,161,355.5
|
||||||
|
161,110,367.7
|
||||||
|
110,160,102.0
|
||||||
|
72,159,342.9
|
||||||
|
159,50,383.3
|
||||||
|
50,74,354.1
|
||||||
|
74,82,350.2
|
||||||
|
82,81,335.4
|
||||||
|
81,99,391.6
|
||||||
|
99,84,354.9
|
||||||
|
84,13,306.4
|
||||||
|
13,40,327.4
|
||||||
|
40,162,413.9
|
||||||
|
162,108,301.9
|
||||||
|
108,146,317.8
|
||||||
|
146,85,376.6
|
||||||
|
85,90,347.0
|
||||||
|
26,27,341.6
|
||||||
|
27,6,359.4
|
||||||
|
6,149,417.8
|
||||||
|
149,126,388.0
|
||||||
|
126,124,384.3
|
||||||
|
124,7,763.3
|
||||||
|
7,114,323.1
|
||||||
|
114,113,351.6
|
||||||
|
113,116,411.9
|
||||||
|
116,144,262.0
|
||||||
|
25,103,350.2
|
||||||
|
103,23,376.3
|
||||||
|
23,165,396.4
|
||||||
|
165,38,381.0
|
||||||
|
38,92,368.0
|
||||||
|
92,37,336.3
|
||||||
|
37,130,357.8
|
||||||
|
130,106,532.3
|
||||||
|
106,131,166.5
|
||||||
|
1,2,371.6
|
||||||
|
2,4,338.1
|
||||||
|
4,76,429.0
|
||||||
|
76,36,366.1
|
||||||
|
36,158,344.5
|
||||||
|
158,151,350.1
|
||||||
|
151,45,358.8
|
||||||
|
45,93,340.9
|
||||||
|
93,80,329.9
|
||||||
|
80,79,384.1
|
||||||
|
79,95,335.7
|
||||||
|
95,98,320.9
|
||||||
|
98,119,340.3
|
||||||
|
119,120,376.8
|
||||||
|
120,122,393.1
|
||||||
|
122,141,428.7
|
||||||
|
141,142,359.3
|
||||||
|
30,165,379.6
|
||||||
|
165,29,41.7
|
||||||
|
29,38,343.3
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
17,153,375.6
|
||||||
|
153,5,256.3
|
||||||
|
153,62,330.9
|
||||||
|
18,6,499.4
|
||||||
|
6,22,254.0
|
||||||
|
22,149,185.4
|
||||||
|
22,4,257.9
|
||||||
|
4,30,236.8
|
||||||
|
30,76,307.0
|
||||||
|
95,98,320.9
|
||||||
|
98,144,45.1
|
||||||
|
45,93,340.9
|
||||||
|
93,106,112.2
|
||||||
|
162,151,113.6
|
||||||
|
151,108,192.9
|
||||||
|
108,45,359.8
|
||||||
|
146,92,311.2
|
||||||
|
92,85,343.9
|
||||||
|
85,37,373.2
|
||||||
|
13,169,326.2
|
||||||
|
169,40,96.1
|
||||||
|
124,13,460.7
|
||||||
|
13,7,305.5
|
||||||
|
7,40,624.1
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
90,132,152.2
|
||||||
|
26,32,106.7
|
||||||
|
9,129,148.3
|
||||||
|
129,153,219.6
|
||||||
|
31,26,116.0
|
||||||
|
26,155,270.7
|
||||||
|
9,128,142.2
|
||||||
|
128,153,215.0
|
||||||
|
153,167,269.7
|
||||||
|
167,62,64.8
|
||||||
|
62,70,332.6
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
44,169,397.8
|
||||||
|
169,41,124.0
|
||||||
|
44,124,375.7
|
||||||
|
124,41,243.9
|
||||||
|
41,7,519.4
|
||||||
|
6,14,289.3
|
||||||
|
14,149,259.0
|
||||||
|
149,39,206.9
|
||||||
|
144,98,45.1
|
||||||
|
19,4,326.8
|
||||||
|
4,14,178.6
|
||||||
|
14,76,299.0
|
||||||
|
15,151,136.4
|
||||||
|
151,44,203.1
|
||||||
|
45,106,260.6
|
||||||
|
106,93,112.2
|
||||||
|
20,165,132.5
|
||||||
|
165,19,289.2
|
||||||
|
89,92,323.2
|
||||||
|
92,147,321.9
|
||||||
|
147,37,48.2
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
104,147,277.5
|
||||||
|
147,148,84.7
|
||||||
|
20,101,201.2
|
||||||
|
101,19,534.4
|
||||||
|
19,137,245.5
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
44,42,342.3
|
||||||
|
42,41,102.5
|
||||||
|
44,8,789.1
|
||||||
|
8,41,657.4
|
||||||
|
41,117,160.5
|
||||||
|
168,167,172.4
|
||||||
|
167,154,165.2
|
||||||
|
143,128,81.9
|
||||||
|
128,10,88.2
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
15,152,135.0
|
||||||
|
152,44,204.6
|
||||||
|
19,77,320.7
|
||||||
|
77,14,299.8
|
||||||
|
14,52,127.6
|
||||||
|
14,127,314.8
|
||||||
|
127,39,280.4
|
||||||
|
39,109,237.0
|
||||||
|
31,160,116.5
|
||||||
|
160,155,272.4
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
32,160,107.7
|
||||||
|
72,162,3274.4
|
||||||
|
162,13,554.5
|
||||||
|
162,40,413.9
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
13,42,319.8
|
||||||
|
42,40,40.7
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
8,13,450.3
|
||||||
|
13,117,378.5
|
||||||
|
117,40,64.0
|
||||||
|
46,162,391.6
|
||||||
|
162,152,115.3
|
||||||
|
152,108,191.4
|
||||||
|
104,108,375.9
|
||||||
|
108,148,311.6
|
||||||
|
148,146,80.0
|
||||||
|
21,90,396.9
|
||||||
|
90,132,152.2
|
||||||
|
101,29,252.3
|
||||||
|
29,137,110.7
|
||||||
|
77,22,353.8
|
||||||
|
22,52,227.8
|
||||||
|
52,30,186.6
|
||||||
|
127,18,425.2
|
||||||
|
18,109,439.1
|
||||||
|
109,22,135.5
|
||||||
|
168,17,232.7
|
||||||
|
17,154,294.2
|
||||||
|
154,5,166.3
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def load_dataset(config):
|
def load_dataset(config):
|
||||||
|
|
||||||
dataset_name = config['basic']['dataset']
|
dataset_name = config['basic']['dataset']
|
||||||
node_num = config['data']['num_nodes']
|
node_num = config['data']['num_nodes']
|
||||||
input_dim = config['data']['input_dim']
|
input_dim = config['data']['input_dim']
|
||||||
|
|
@ -10,4 +11,8 @@ def load_dataset(config):
|
||||||
case 'EcoSolar':
|
case 'EcoSolar':
|
||||||
data_path = os.path.join('./data/EcoSolar.npy')
|
data_path = os.path.join('./data/EcoSolar.npy')
|
||||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||||
|
case 'PEMS08':
|
||||||
|
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||||
|
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def load_graph(config):
|
||||||
|
dataset_path = config['data']['graph_pkl_filename']
|
||||||
|
graph = np.load(dataset_path)
|
||||||
|
# 将inf值填充为0
|
||||||
|
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 100
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 100
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
2
main.py
2
main.py
|
|
@ -3,8 +3,6 @@
|
||||||
时空数据深度学习预测项目主程序
|
时空数据深度学习预测项目主程序
|
||||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from utils.args_reader import config_loader
|
from utils.args_reader import config_loader
|
||||||
import utils.init as init
|
import utils.init as init
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
from models.STDEN.stden_model import STDENModel
|
||||||
|
|
||||||
def model_selector(config):
|
def model_selector(config):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
model = None
|
model = None
|
||||||
|
match model_name:
|
||||||
|
case 'STDEN': model = STDENModel(config)
|
||||||
return model
|
return model
|
||||||
|
|
@ -0,0 +1,576 @@
|
||||||
|
import math
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import copy
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
class Trainer:
|
||||||
|
def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
|
||||||
|
scalers, logger, lr_scheduler=None):
|
||||||
|
self.model = model
|
||||||
|
self.loss = loss
|
||||||
|
self.optimizer = optimizer
|
||||||
|
self.train_loader = train_loader
|
||||||
|
self.val_loader = val_loader
|
||||||
|
self.test_loader = test_loader
|
||||||
|
self.scalers = scalers # 现在是多个标准化器的列表
|
||||||
|
self.args = config['train']
|
||||||
|
self.logger = logger
|
||||||
|
self.args['device'] = config['basic']['device']
|
||||||
|
self.lr_scheduler = lr_scheduler
|
||||||
|
self.train_per_epoch = len(train_loader)
|
||||||
|
self.val_per_epoch = len(val_loader) if val_loader else 0
|
||||||
|
self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
|
||||||
|
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
||||||
|
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
||||||
|
|
||||||
|
# 用于收集nfe数据
|
||||||
|
self.c = []
|
||||||
|
self.res, self.keys = [], []
|
||||||
|
|
||||||
|
def _run_epoch(self, epoch, dataloader, mode):
|
||||||
|
if mode == 'train':
|
||||||
|
self.model.train()
|
||||||
|
optimizer_step = True
|
||||||
|
else:
|
||||||
|
self.model.eval()
|
||||||
|
optimizer_step = False
|
||||||
|
|
||||||
|
total_loss = 0
|
||||||
|
epoch_time = time.time()
|
||||||
|
|
||||||
|
# 清空nfe数据收集
|
||||||
|
if mode == 'train':
|
||||||
|
self.c.clear()
|
||||||
|
|
||||||
|
with torch.set_grad_enabled(optimizer_step):
|
||||||
|
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||||
|
for batch_idx, (data, target) in enumerate(dataloader):
|
||||||
|
label = target[..., :self.args['output_dim']]
|
||||||
|
output, fe = self.model(data)
|
||||||
|
|
||||||
|
if self.args['real_value']:
|
||||||
|
# 只对输出维度进行反归一化
|
||||||
|
output = self._inverse_transform_output(output)
|
||||||
|
|
||||||
|
loss = self.loss(output, label)
|
||||||
|
|
||||||
|
# 收集nfe数据(仅在训练模式下)
|
||||||
|
if mode == 'train':
|
||||||
|
self.c.append([*fe, loss.item()])
|
||||||
|
# 记录FE信息
|
||||||
|
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
||||||
|
|
||||||
|
if optimizer_step and self.optimizer is not None:
|
||||||
|
self.optimizer.zero_grad()
|
||||||
|
loss.backward()
|
||||||
|
|
||||||
|
if self.args['grad_norm']:
|
||||||
|
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||||
|
self.optimizer.step()
|
||||||
|
|
||||||
|
total_loss += loss.item()
|
||||||
|
|
||||||
|
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||||
|
self.logger.info(
|
||||||
|
f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
|
||||||
|
|
||||||
|
# 更新 tqdm 的进度
|
||||||
|
pbar.update(1)
|
||||||
|
pbar.set_postfix(loss=loss.item())
|
||||||
|
|
||||||
|
avg_loss = total_loss / len(dataloader)
|
||||||
|
self.logger.logger.info(
|
||||||
|
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||||
|
|
||||||
|
# 收集nfe数据(仅在训练模式下)
|
||||||
|
if mode == 'train':
|
||||||
|
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
|
||||||
|
self.keys.append(epoch)
|
||||||
|
|
||||||
|
return avg_loss
|
||||||
|
|
||||||
|
def _inverse_transform_output(self, output):
|
||||||
|
"""
|
||||||
|
只对输出维度进行反归一化
|
||||||
|
假设输出数据形状为 [batch, horizon, nodes, features]
|
||||||
|
只对前output_dim个特征进行反归一化
|
||||||
|
"""
|
||||||
|
if not self.args['real_value']:
|
||||||
|
return output
|
||||||
|
|
||||||
|
# 获取输出维度的数量
|
||||||
|
output_dim = self.args['output_dim']
|
||||||
|
|
||||||
|
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||||
|
if output_dim <= len(self.scalers):
|
||||||
|
# 对每个输出特征分别进行反归一化
|
||||||
|
for feature_idx in range(output_dim):
|
||||||
|
if feature_idx < len(self.scalers):
|
||||||
|
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||||
|
for feature_idx in range(len(self.scalers)):
|
||||||
|
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def train_epoch(self, epoch):
|
||||||
|
return self._run_epoch(epoch, self.train_loader, 'train')
|
||||||
|
|
||||||
|
def val_epoch(self, epoch):
|
||||||
|
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
|
||||||
|
|
||||||
|
def test_epoch(self, epoch):
|
||||||
|
return self._run_epoch(epoch, self.test_loader, 'test')
|
||||||
|
|
||||||
|
def train(self):
|
||||||
|
best_model, best_test_model = None, None
|
||||||
|
best_loss, best_test_loss = float('inf'), float('inf')
|
||||||
|
not_improved_count = 0
|
||||||
|
|
||||||
|
self.logger.logger.info("Training process started")
|
||||||
|
for epoch in range(1, self.args['epochs'] + 1):
|
||||||
|
train_epoch_loss = self.train_epoch(epoch)
|
||||||
|
val_epoch_loss = self.val_epoch(epoch)
|
||||||
|
test_epoch_loss = self.test_epoch(epoch)
|
||||||
|
|
||||||
|
if train_epoch_loss > 1e6:
|
||||||
|
self.logger.logger.warning('Gradient explosion detected. Ending...')
|
||||||
|
break
|
||||||
|
|
||||||
|
if val_epoch_loss < best_loss:
|
||||||
|
best_loss = val_epoch_loss
|
||||||
|
not_improved_count = 0
|
||||||
|
best_model = copy.deepcopy(self.model.state_dict())
|
||||||
|
torch.save(best_model, self.best_path)
|
||||||
|
self.logger.logger.info('Best validation model saved!')
|
||||||
|
else:
|
||||||
|
not_improved_count += 1
|
||||||
|
|
||||||
|
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
|
||||||
|
self.logger.logger.info(
|
||||||
|
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
|
||||||
|
break
|
||||||
|
|
||||||
|
if test_epoch_loss < best_test_loss:
|
||||||
|
best_test_loss = test_epoch_loss
|
||||||
|
best_test_model = copy.deepcopy(self.model.state_dict())
|
||||||
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
||||||
|
self._save_nfe_data()
|
||||||
|
|
||||||
|
if not self.args['debug']:
|
||||||
|
torch.save(best_model, self.best_path)
|
||||||
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||||
|
|
||||||
|
self._finalize_training(best_model, best_test_model)
|
||||||
|
|
||||||
|
def _finalize_training(self, best_model, best_test_model):
|
||||||
|
self.model.load_state_dict(best_model)
|
||||||
|
self.logger.logger.info("Testing on best validation model")
|
||||||
|
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
|
||||||
|
|
||||||
|
self.model.load_state_dict(best_test_model)
|
||||||
|
self.logger.logger.info("Testing on best test model")
|
||||||
|
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
|
||||||
|
if path:
|
||||||
|
checkpoint = torch.load(path)
|
||||||
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
|
model.to(args['device'])
|
||||||
|
|
||||||
|
model.eval()
|
||||||
|
y_pred, y_true = [], []
|
||||||
|
|
||||||
|
# 用于收集nfe数据
|
||||||
|
c = []
|
||||||
|
|
||||||
|
with torch.no_grad():
|
||||||
|
for data, target in data_loader:
|
||||||
|
label = target[..., :args['output_dim']]
|
||||||
|
output, fe = model(data)
|
||||||
|
y_pred.append(output)
|
||||||
|
y_true.append(label)
|
||||||
|
|
||||||
|
# 收集nfe数据
|
||||||
|
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
||||||
|
|
||||||
|
if args['real_value']:
|
||||||
|
# 只对输出维度进行反归一化
|
||||||
|
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
||||||
|
else:
|
||||||
|
y_pred = torch.cat(y_pred, dim=0)
|
||||||
|
y_true = torch.cat(y_true, dim=0)
|
||||||
|
|
||||||
|
# 计算每个时间步的指标
|
||||||
|
for t in range(y_true.shape[1]):
|
||||||
|
mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
||||||
|
args['mae_thresh'], args['mape_thresh'])
|
||||||
|
logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
|
|
||||||
|
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
||||||
|
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
||||||
|
Trainer._save_nfe_data_static(c, model, logger)
|
||||||
|
|
||||||
|
# 只在需要时生成可视化图片
|
||||||
|
if generate_viz:
|
||||||
|
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
||||||
|
Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
|
||||||
|
Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
||||||
|
target_node=1, num_samples=10, scalers=scalers)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _inverse_transform_output_static(output, args, scalers):
|
||||||
|
"""
|
||||||
|
静态方法:只对输出维度进行反归一化
|
||||||
|
"""
|
||||||
|
if not args['real_value']:
|
||||||
|
return output
|
||||||
|
|
||||||
|
# 获取输出维度的数量
|
||||||
|
output_dim = args['output_dim']
|
||||||
|
|
||||||
|
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||||
|
if output_dim <= len(scalers):
|
||||||
|
# 对每个输出特征分别进行反归一化
|
||||||
|
for feature_idx in range(output_dim):
|
||||||
|
if feature_idx < len(scalers):
|
||||||
|
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||||
|
for feature_idx in range(len(scalers)):
|
||||||
|
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||||
|
output[..., feature_idx:feature_idx+1]
|
||||||
|
)
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
|
||||||
|
"""
|
||||||
|
生成节点预测可视化图片
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred: 预测值
|
||||||
|
y_true: 真实值
|
||||||
|
logger: 日志记录器
|
||||||
|
save_dir: 保存目录
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import matplotlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# 设置matplotlib配置,减少字体查找输出
|
||||||
|
matplotlib.set_loglevel('error') # 只显示错误信息
|
||||||
|
plt.rcParams['font.family'] = 'DejaVu Sans' # 使用默认字体
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
if y_pred is None or y_true is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建pic文件夹
|
||||||
|
pic_dir = os.path.join(save_dir, 'pic')
|
||||||
|
os.makedirs(pic_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 固定生成10张图片
|
||||||
|
num_nodes_to_plot = 10
|
||||||
|
|
||||||
|
# 生成单个节点的详细图
|
||||||
|
with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
|
||||||
|
for node_id in range(num_nodes_to_plot):
|
||||||
|
# 获取对应节点的数据
|
||||||
|
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
||||||
|
# 数据格式: [time_step, seq_len, num_node, dim]
|
||||||
|
node_pred = y_pred[:, 12, node_id, 0].cpu().numpy() # t=1时刻,指定节点,第一个特征
|
||||||
|
node_true = y_true[:, 12, node_id, 0].cpu().numpy()
|
||||||
|
else:
|
||||||
|
# 如果数据不足10个节点,只处理实际存在的节点
|
||||||
|
if node_id >= y_pred.shape[-2]:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
|
||||||
|
node_true = y_true[:, 0, node_id, 0].cpu().numpy()
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 取前500个时间步
|
||||||
|
max_steps = min(500, len(node_pred))
|
||||||
|
if max_steps <= 0:
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
node_pred_500 = node_pred[:max_steps]
|
||||||
|
node_true_500 = node_true[:max_steps]
|
||||||
|
|
||||||
|
# 创建时间轴
|
||||||
|
time_steps = np.arange(max_steps)
|
||||||
|
|
||||||
|
# 绘制对比图
|
||||||
|
plt.figure(figsize=(12, 6))
|
||||||
|
plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
|
||||||
|
plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
|
||||||
|
plt.xlabel('Time Steps')
|
||||||
|
plt.ylabel('Values')
|
||||||
|
plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# 保存图片,使用不同的命名
|
||||||
|
save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# 生成所有节点的对比图(前100个时间步,便于观察)
|
||||||
|
# 选择前100个时间步
|
||||||
|
plot_steps = min(100, y_pred.shape[0])
|
||||||
|
if plot_steps <= 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 创建子图
|
||||||
|
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for node_id in range(num_nodes_to_plot):
|
||||||
|
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
||||||
|
# 数据格式: [time_step, seq_len, num_node, dim]
|
||||||
|
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||||
|
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||||
|
else:
|
||||||
|
# 如果数据不足10个节点,只处理实际存在的节点
|
||||||
|
if node_id >= y_pred.shape[-2]:
|
||||||
|
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
||||||
|
ha='center', va='center', transform=axes[node_id].transAxes)
|
||||||
|
continue
|
||||||
|
else:
|
||||||
|
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||||
|
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
||||||
|
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
||||||
|
ha='center', va='center', transform=axes[node_id].transAxes)
|
||||||
|
continue
|
||||||
|
|
||||||
|
time_steps = np.arange(plot_steps)
|
||||||
|
|
||||||
|
axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
|
||||||
|
axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
|
||||||
|
axes[node_id].set_title(f'Node {node_id + 1}')
|
||||||
|
axes[node_id].grid(True, alpha=0.3)
|
||||||
|
axes[node_id].legend(fontsize=8)
|
||||||
|
|
||||||
|
if node_id >= 5: # 下面一行添加x轴标签
|
||||||
|
axes[node_id].set_xlabel('Time Steps')
|
||||||
|
if node_id % 5 == 0: # 左边一列添加y轴标签
|
||||||
|
axes[node_id].set_ylabel('Values')
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
|
||||||
|
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
||||||
|
target_node=1, num_samples=10, scalers=None):
|
||||||
|
"""
|
||||||
|
生成输入-输出样本比较图
|
||||||
|
|
||||||
|
Args:
|
||||||
|
y_pred: 预测值
|
||||||
|
y_true: 真实值
|
||||||
|
data_loader: 数据加载器,用于获取输入数据
|
||||||
|
logger: 日志记录器
|
||||||
|
save_dir: 保存目录
|
||||||
|
target_node: 目标节点ID(从1开始)
|
||||||
|
num_samples: 要比较的样本数量
|
||||||
|
scalers: 标准化器列表,用于反归一化输入数据
|
||||||
|
"""
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import numpy as np
|
||||||
|
import os
|
||||||
|
import matplotlib
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
# 设置matplotlib配置
|
||||||
|
matplotlib.set_loglevel('error')
|
||||||
|
plt.rcParams['font.family'] = 'DejaVu Sans'
|
||||||
|
|
||||||
|
# 创建compare文件夹
|
||||||
|
compare_dir = os.path.join(save_dir, 'pic', 'compare')
|
||||||
|
os.makedirs(compare_dir, exist_ok=True)
|
||||||
|
|
||||||
|
# 获取输入数据
|
||||||
|
input_data = []
|
||||||
|
for batch_idx, (data, target) in enumerate(data_loader):
|
||||||
|
if batch_idx >= num_samples:
|
||||||
|
break
|
||||||
|
input_data.append(data.cpu().numpy())
|
||||||
|
|
||||||
|
if not input_data:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 获取目标节点的索引(从0开始)
|
||||||
|
node_idx = target_node - 1
|
||||||
|
|
||||||
|
# 检查节点索引是否有效
|
||||||
|
if node_idx >= y_pred.shape[-2]:
|
||||||
|
return
|
||||||
|
|
||||||
|
# 为每个样本生成比较图
|
||||||
|
with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
|
||||||
|
for sample_idx in range(min(num_samples, len(input_data))):
|
||||||
|
# 获取输入序列(假设输入形状为 [batch, seq_len, nodes, features])
|
||||||
|
input_seq = input_data[sample_idx][0, :, node_idx, 0] # 第一个batch,所有时间步,目标节点,第一个特征
|
||||||
|
|
||||||
|
# 对输入数据进行反归一化
|
||||||
|
if scalers is not None and len(scalers) > 0:
|
||||||
|
# 使用第一个标准化器对输入进行反归一化(假设输入特征使用第一个标准化器)
|
||||||
|
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
||||||
|
|
||||||
|
# 获取对应的预测值和真实值
|
||||||
|
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy() # 所有horizon,目标节点,第一个特征
|
||||||
|
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
|
||||||
|
pbar.update(1)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 创建时间轴 - 输入和输出连续
|
||||||
|
total_time = np.arange(len(input_seq) + len(pred_seq))
|
||||||
|
|
||||||
|
# 创建合并的图形 - 输入和输出在同一个图中
|
||||||
|
plt.figure(figsize=(14, 8))
|
||||||
|
|
||||||
|
# 绘制完整的真实值曲线(输入 + 真实输出)
|
||||||
|
true_combined = np.concatenate([input_seq, true_seq])
|
||||||
|
plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
|
||||||
|
linewidth=2.5, alpha=0.9, linestyle='-')
|
||||||
|
|
||||||
|
# 绘制预测值曲线(只绘制输出部分)
|
||||||
|
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
||||||
|
plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
|
||||||
|
linewidth=2, alpha=0.8, linestyle='-')
|
||||||
|
|
||||||
|
# 添加垂直线分隔输入和输出
|
||||||
|
plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
|
||||||
|
label='Input/Output Boundary')
|
||||||
|
|
||||||
|
# 设置图形属性
|
||||||
|
plt.xlabel('Time Steps')
|
||||||
|
plt.ylabel('Values')
|
||||||
|
plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
|
||||||
|
plt.legend()
|
||||||
|
plt.grid(True, alpha=0.3)
|
||||||
|
|
||||||
|
# 调整布局
|
||||||
|
plt.tight_layout()
|
||||||
|
|
||||||
|
# 保存图片
|
||||||
|
save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
|
||||||
|
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
pbar.update(1)
|
||||||
|
|
||||||
|
# 生成汇总图(所有样本的预测值对比)
|
||||||
|
|
||||||
|
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
||||||
|
axes = axes.flatten()
|
||||||
|
|
||||||
|
for sample_idx in range(min(num_samples, len(input_data))):
|
||||||
|
if sample_idx >= 10: # 最多显示10个子图
|
||||||
|
break
|
||||||
|
|
||||||
|
ax = axes[sample_idx]
|
||||||
|
|
||||||
|
# 获取输入序列和预测值、真实值
|
||||||
|
input_seq = input_data[sample_idx][0, :, node_idx, 0]
|
||||||
|
if scalers is not None and len(scalers) > 0:
|
||||||
|
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
||||||
|
|
||||||
|
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||||
|
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||||
|
|
||||||
|
# 检查数据有效性
|
||||||
|
if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
|
||||||
|
ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
|
||||||
|
ha='center', va='center', transform=ax.transAxes)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 绘制对比图 - 输入和输出连续显示
|
||||||
|
total_time = np.arange(len(input_seq) + len(pred_seq))
|
||||||
|
true_combined = np.concatenate([input_seq, true_seq])
|
||||||
|
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
||||||
|
|
||||||
|
ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
|
||||||
|
ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
|
||||||
|
ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
|
||||||
|
ax.set_title(f'Sample {sample_idx + 1}')
|
||||||
|
ax.grid(True, alpha=0.3)
|
||||||
|
ax.legend(fontsize=8)
|
||||||
|
|
||||||
|
if sample_idx >= 5: # 下面一行添加x轴标签
|
||||||
|
ax.set_xlabel('Time Steps')
|
||||||
|
if sample_idx % 5 == 0: # 左边一列添加y轴标签
|
||||||
|
ax.set_ylabel('Values')
|
||||||
|
|
||||||
|
# 隐藏多余的子图
|
||||||
|
for i in range(min(num_samples, len(input_data)), 10):
|
||||||
|
axes[i].set_visible(False)
|
||||||
|
|
||||||
|
plt.tight_layout()
|
||||||
|
summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
|
||||||
|
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||||
|
plt.close()
|
||||||
|
|
||||||
|
def _save_nfe_data(self):
|
||||||
|
"""保存nfe数据到文件"""
|
||||||
|
if not self.res:
|
||||||
|
return
|
||||||
|
|
||||||
|
res = pd.concat(self.res, keys=self.keys)
|
||||||
|
res.index.names = ['epoch', 'iter']
|
||||||
|
|
||||||
|
# 获取模型配置参数
|
||||||
|
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
||||||
|
atol = getattr(self.model, 'atol', 1e-5)
|
||||||
|
rtol = getattr(self.model, 'rtol', 1e-5)
|
||||||
|
|
||||||
|
# 保存nfe数据
|
||||||
|
nfe_file = os.path.join(
|
||||||
|
self.logger.dir_path,
|
||||||
|
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||||
|
res.to_pickle(nfe_file)
|
||||||
|
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _compute_sampling_threshold(global_step, k):
|
||||||
|
return k / (k + math.exp(global_step / k))
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
from trainer.trainer import Trainer
|
from trainer.trainer import Trainer
|
||||||
|
from trainer.ode_trainer import Trainer as ode_trainer
|
||||||
|
|
||||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||||
lr_scheduler, kwargs):
|
lr_scheduler, kwargs):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
selected_Trainer = None
|
selected_Trainer = None
|
||||||
match model_name:
|
match model_name:
|
||||||
|
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
||||||
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||||
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
if selected_Trainer is None: raise NotImplementedError
|
if selected_Trainer is None: raise NotImplementedError
|
||||||
return selected_Trainer
|
return selected_Trainer
|
||||||
Loading…
Reference in New Issue