Compare commits

...

2 Commits

24 changed files with 1341 additions and 4 deletions

1
.gitignore vendored
View File

@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
STDEN/

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/Project-I.iml Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="136" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="8">
<item index="0" class="java.lang.String" itemvalue="argparse" />
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
<item index="5" class="java.lang.String" itemvalue="setuptools" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
</modules>
</component>
</project>

7
.idea/vcs.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

296
data/PEMS08/PEMS08.csv Executable file
View File

@ -0,0 +1,296 @@
from,to,cost
9,153,310.6
153,62,330.9
62,111,332.9
111,11,324.2
11,28,336.0
28,169,133.7
138,135,354.7
135,133,387.9
133,163,337.1
163,20,352.0
20,19,420.8
19,14,351.3
14,39,340.2
39,164,350.3
164,167,365.2
167,70,359.0
70,59,388.2
59,58,305.7
58,67,294.4
67,66,299.5
66,55,313.3
55,53,332.1
53,150,278.9
150,61,308.4
61,64,311.4
64,63,243.6
47,65,372.8
65,48,319.4
48,49,309.7
49,54,320.5
54,56,318.3
56,57,297.9
57,68,293.5
68,69,342.5
69,60,318.0
60,17,305.9
17,5,321.4
5,18,402.2
18,22,447.4
22,30,377.5
30,29,417.7
29,21,360.8
21,132,407.6
132,134,386.9
134,136,350.2
123,121,326.3
121,140,385.2
140,118,393.0
118,96,296.7
96,94,398.2
94,86,337.1
86,78,473.8
78,46,353.4
46,152,385.7
152,157,350.0
157,35,354.4
35,77,356.1
77,52,354.2
52,3,357.8
3,16,382.4
16,0,55.7
42,12,335.1
12,139,328.8
139,168,412.6
168,154,337.3
154,143,370.7
143,10,6.3
107,105,354.6
105,104,386.9
104,148,362.1
148,97,316.3
97,101,380.7
101,137,361.4
137,102,365.5
102,24,375.5
24,166,312.2
129,156,256.1
156,33,329.1
33,32,356.5
91,89,405.6
89,147,347.0
147,15,351.7
15,44,339.5
44,41,350.8
41,43,322.6
43,100,338.9
100,83,347.9
83,87,327.2
87,88,321.0
88,75,335.8
75,51,384.8
51,73,391.1
73,71,289.3
31,155,260.0
155,34,320.4
34,128,393.3
145,115,399.4
115,112,328.1
112,8,469.4
8,117,816.2
117,125,397.1
125,127,372.7
127,109,380.5
109,161,355.5
161,110,367.7
110,160,102.0
72,159,342.9
159,50,383.3
50,74,354.1
74,82,350.2
82,81,335.4
81,99,391.6
99,84,354.9
84,13,306.4
13,40,327.4
40,162,413.9
162,108,301.9
108,146,317.8
146,85,376.6
85,90,347.0
26,27,341.6
27,6,359.4
6,149,417.8
149,126,388.0
126,124,384.3
124,7,763.3
7,114,323.1
114,113,351.6
113,116,411.9
116,144,262.0
25,103,350.2
103,23,376.3
23,165,396.4
165,38,381.0
38,92,368.0
92,37,336.3
37,130,357.8
130,106,532.3
106,131,166.5
1,2,371.6
2,4,338.1
4,76,429.0
76,36,366.1
36,158,344.5
158,151,350.1
151,45,358.8
45,93,340.9
93,80,329.9
80,79,384.1
79,95,335.7
95,98,320.9
98,119,340.3
119,120,376.8
120,122,393.1
122,141,428.7
141,142,359.3
30,165,379.6
165,29,41.7
29,38,343.3
65,72,297.9
72,48,21.5
17,153,375.6
153,5,256.3
153,62,330.9
18,6,499.4
6,22,254.0
22,149,185.4
22,4,257.9
4,30,236.8
30,76,307.0
95,98,320.9
98,144,45.1
45,93,340.9
93,106,112.2
162,151,113.6
151,108,192.9
108,45,359.8
146,92,311.2
92,85,343.9
85,37,373.2
13,169,326.2
169,40,96.1
124,13,460.7
13,7,305.5
7,40,624.1
124,169,145.2
169,7,631.5
90,132,152.2
26,32,106.7
9,129,148.3
129,153,219.6
31,26,116.0
26,155,270.7
9,128,142.2
128,153,215.0
153,167,269.7
167,62,64.8
62,70,332.6
124,169,145.2
169,7,631.5
44,169,397.8
169,41,124.0
44,124,375.7
124,41,243.9
41,7,519.4
6,14,289.3
14,149,259.0
149,39,206.9
144,98,45.1
19,4,326.8
4,14,178.6
14,76,299.0
15,151,136.4
151,44,203.1
45,106,260.6
106,93,112.2
20,165,132.5
165,19,289.2
89,92,323.2
92,147,321.9
147,37,48.2
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
78,107,143.9
107,46,236.3
104,147,277.5
147,148,84.7
20,101,201.2
101,19,534.4
19,137,245.5
8,42,759.5
42,117,58.9
44,42,342.3
42,41,102.5
44,8,789.1
8,41,657.4
41,117,160.5
168,167,172.4
167,154,165.2
143,128,81.9
128,10,88.2
118,145,250.6
145,96,85.1
15,152,135.0
152,44,204.6
19,77,320.7
77,14,299.8
14,52,127.6
14,127,314.8
127,39,280.4
39,109,237.0
31,160,116.5
160,155,272.4
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
32,160,107.7
72,162,3274.4
162,13,554.5
162,40,413.9
65,72,297.9
72,48,21.5
13,42,319.8
42,40,40.7
8,42,759.5
42,117,58.9
8,13,450.3
13,117,378.5
117,40,64.0
46,162,391.6
162,152,115.3
152,108,191.4
104,108,375.9
108,148,311.6
148,146,80.0
21,90,396.9
90,132,152.2
101,29,252.3
29,137,110.7
77,22,353.8
22,52,227.8
52,30,186.6
127,18,425.2
18,109,439.1
109,22,135.5
168,17,232.7
17,154,294.2
154,5,166.3
78,107,143.9
107,46,236.3
118,145,250.6
145,96,85.1
1 from to cost
2 9 153 310.6
3 153 62 330.9
4 62 111 332.9
5 111 11 324.2
6 11 28 336.0
7 28 169 133.7
8 138 135 354.7
9 135 133 387.9
10 133 163 337.1
11 163 20 352.0
12 20 19 420.8
13 19 14 351.3
14 14 39 340.2
15 39 164 350.3
16 164 167 365.2
17 167 70 359.0
18 70 59 388.2
19 59 58 305.7
20 58 67 294.4
21 67 66 299.5
22 66 55 313.3
23 55 53 332.1
24 53 150 278.9
25 150 61 308.4
26 61 64 311.4
27 64 63 243.6
28 47 65 372.8
29 65 48 319.4
30 48 49 309.7
31 49 54 320.5
32 54 56 318.3
33 56 57 297.9
34 57 68 293.5
35 68 69 342.5
36 69 60 318.0
37 60 17 305.9
38 17 5 321.4
39 5 18 402.2
40 18 22 447.4
41 22 30 377.5
42 30 29 417.7
43 29 21 360.8
44 21 132 407.6
45 132 134 386.9
46 134 136 350.2
47 123 121 326.3
48 121 140 385.2
49 140 118 393.0
50 118 96 296.7
51 96 94 398.2
52 94 86 337.1
53 86 78 473.8
54 78 46 353.4
55 46 152 385.7
56 152 157 350.0
57 157 35 354.4
58 35 77 356.1
59 77 52 354.2
60 52 3 357.8
61 3 16 382.4
62 16 0 55.7
63 42 12 335.1
64 12 139 328.8
65 139 168 412.6
66 168 154 337.3
67 154 143 370.7
68 143 10 6.3
69 107 105 354.6
70 105 104 386.9
71 104 148 362.1
72 148 97 316.3
73 97 101 380.7
74 101 137 361.4
75 137 102 365.5
76 102 24 375.5
77 24 166 312.2
78 129 156 256.1
79 156 33 329.1
80 33 32 356.5
81 91 89 405.6
82 89 147 347.0
83 147 15 351.7
84 15 44 339.5
85 44 41 350.8
86 41 43 322.6
87 43 100 338.9
88 100 83 347.9
89 83 87 327.2
90 87 88 321.0
91 88 75 335.8
92 75 51 384.8
93 51 73 391.1
94 73 71 289.3
95 31 155 260.0
96 155 34 320.4
97 34 128 393.3
98 145 115 399.4
99 115 112 328.1
100 112 8 469.4
101 8 117 816.2
102 117 125 397.1
103 125 127 372.7
104 127 109 380.5
105 109 161 355.5
106 161 110 367.7
107 110 160 102.0
108 72 159 342.9
109 159 50 383.3
110 50 74 354.1
111 74 82 350.2
112 82 81 335.4
113 81 99 391.6
114 99 84 354.9
115 84 13 306.4
116 13 40 327.4
117 40 162 413.9
118 162 108 301.9
119 108 146 317.8
120 146 85 376.6
121 85 90 347.0
122 26 27 341.6
123 27 6 359.4
124 6 149 417.8
125 149 126 388.0
126 126 124 384.3
127 124 7 763.3
128 7 114 323.1
129 114 113 351.6
130 113 116 411.9
131 116 144 262.0
132 25 103 350.2
133 103 23 376.3
134 23 165 396.4
135 165 38 381.0
136 38 92 368.0
137 92 37 336.3
138 37 130 357.8
139 130 106 532.3
140 106 131 166.5
141 1 2 371.6
142 2 4 338.1
143 4 76 429.0
144 76 36 366.1
145 36 158 344.5
146 158 151 350.1
147 151 45 358.8
148 45 93 340.9
149 93 80 329.9
150 80 79 384.1
151 79 95 335.7
152 95 98 320.9
153 98 119 340.3
154 119 120 376.8
155 120 122 393.1
156 122 141 428.7
157 141 142 359.3
158 30 165 379.6
159 165 29 41.7
160 29 38 343.3
161 65 72 297.9
162 72 48 21.5
163 17 153 375.6
164 153 5 256.3
165 153 62 330.9
166 18 6 499.4
167 6 22 254.0
168 22 149 185.4
169 22 4 257.9
170 4 30 236.8
171 30 76 307.0
172 95 98 320.9
173 98 144 45.1
174 45 93 340.9
175 93 106 112.2
176 162 151 113.6
177 151 108 192.9
178 108 45 359.8
179 146 92 311.2
180 92 85 343.9
181 85 37 373.2
182 13 169 326.2
183 169 40 96.1
184 124 13 460.7
185 13 7 305.5
186 7 40 624.1
187 124 169 145.2
188 169 7 631.5
189 90 132 152.2
190 26 32 106.7
191 9 129 148.3
192 129 153 219.6
193 31 26 116.0
194 26 155 270.7
195 9 128 142.2
196 128 153 215.0
197 153 167 269.7
198 167 62 64.8
199 62 70 332.6
200 124 169 145.2
201 169 7 631.5
202 44 169 397.8
203 169 41 124.0
204 44 124 375.7
205 124 41 243.9
206 41 7 519.4
207 6 14 289.3
208 14 149 259.0
209 149 39 206.9
210 144 98 45.1
211 19 4 326.8
212 4 14 178.6
213 14 76 299.0
214 15 151 136.4
215 151 44 203.1
216 45 106 260.6
217 106 93 112.2
218 20 165 132.5
219 165 19 289.2
220 89 92 323.2
221 92 147 321.9
222 147 37 48.2
223 133 91 152.8
224 91 163 313.6
225 150 71 221.1
226 71 61 89.6
227 78 107 143.9
228 107 46 236.3
229 104 147 277.5
230 147 148 84.7
231 20 101 201.2
232 101 19 534.4
233 19 137 245.5
234 8 42 759.5
235 42 117 58.9
236 44 42 342.3
237 42 41 102.5
238 44 8 789.1
239 8 41 657.4
240 41 117 160.5
241 168 167 172.4
242 167 154 165.2
243 143 128 81.9
244 128 10 88.2
245 118 145 250.6
246 145 96 85.1
247 15 152 135.0
248 152 44 204.6
249 19 77 320.7
250 77 14 299.8
251 14 52 127.6
252 14 127 314.8
253 127 39 280.4
254 39 109 237.0
255 31 160 116.5
256 160 155 272.4
257 133 91 152.8
258 91 163 313.6
259 150 71 221.1
260 71 61 89.6
261 32 160 107.7
262 72 162 3274.4
263 162 13 554.5
264 162 40 413.9
265 65 72 297.9
266 72 48 21.5
267 13 42 319.8
268 42 40 40.7
269 8 42 759.5
270 42 117 58.9
271 8 13 450.3
272 13 117 378.5
273 117 40 64.0
274 46 162 391.6
275 162 152 115.3
276 152 108 191.4
277 104 108 375.9
278 108 148 311.6
279 148 146 80.0
280 21 90 396.9
281 90 132 152.2
282 101 29 252.3
283 29 137 110.7
284 77 22 353.8
285 22 52 227.8
286 52 30 186.6
287 127 18 425.2
288 18 109 439.1
289 109 22 135.5
290 168 17 232.7
291 17 154 294.2
292 154 5 166.3
293 78 107 143.9
294 107 46 236.3
295 118 145 250.6
296 145 96 85.1

BIN
data/PEMS08/PEMS08.npz Executable file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -2,6 +2,7 @@ import numpy as np
import os
def load_dataset(config):
dataset_name = config['basic']['dataset']
node_num = config['data']['num_nodes']
input_dim = config['data']['input_dim']
@ -10,4 +11,8 @@ def load_dataset(config):
case 'EcoSolar':
data_path = os.path.join('./data/EcoSolar.npy')
data = np.load(data_path)[:, :node_num, :input_dim]
case 'PEMS08':
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
data = np.load(data_path)['data'][:, :node_num, :input_dim]
return data

9
data/graph_loader.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
def load_graph(config):
    """Load the spatial graph (adjacency / distance matrix) for the configured dataset.

    Reads the ``.npy`` array referenced by ``config['data']['graph_pkl_filename']``
    and sanitizes it so downstream matrix operations stay finite.

    Args:
        config: nested configuration dict with a ``data.graph_pkl_filename`` entry.

    Returns:
        numpy.ndarray: the graph matrix with NaN, +inf and -inf entries replaced by 0.0.
    """
    graph_path = config['data']['graph_pkl_filename']
    raw_graph = np.load(graph_path)
    # Replace non-finite entries (NaN / +/-inf) with 0.0.
    return np.nan_to_num(raw_graph, nan=0.0, posinf=0.0, neginf=0.0)

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 100
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 100
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -3,8 +3,6 @@
时空数据深度学习预测项目主程序
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
"""
import os
from utils.args_reader import config_loader
import utils.init as init
import torch

View File

@ -1,6 +1,8 @@
from models.STDEN.stden_model import STDENModel
def model_selector(config):
    """Instantiate the model named by ``config['basic']['model']``.

    Returns:
        The constructed model instance, or None when the name is not recognized.
    """
    requested = config['basic']['model']
    # Lazy dispatch: the model class is only referenced when its name matches,
    # mirroring the original match/case behavior (unknown names yield None).
    if requested == 'STDEN':
        return STDENModel(config)
    return None

576
trainer/ode_trainer.py Normal file
View File

@ -0,0 +1,576 @@
import math
import os
import time
import copy
import pandas as pd
import numpy as np
from tqdm import tqdm
import torch
class Trainer:
    def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
                 scalers, logger, lr_scheduler=None):
        """Wire up one training run: model, data loaders, loss/optimizer, and output paths.

        Args:
            config: nested config dict; reads the 'train' section and 'basic.device'.
            model: the model to train.
            loss: loss callable applied to (output, label).
            optimizer: optimizer stepping the model parameters (may be None for eval-only use).
            train_loader / val_loader / test_loader: batch iterables; val_loader may be None.
            scalers: list of per-feature normalizers used for inverse transforms.
            logger: project logger exposing ``.logger`` and ``.dir_path``.
            lr_scheduler: optional LR scheduler (stored, not stepped here).
        """
        self.model = model
        self.loss = loss
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.scalers = scalers  # now a list of scalers, one per feature
        self.args = config['train']
        self.logger = logger
        self.args['device'] = config['basic']['device']
        self.lr_scheduler = lr_scheduler
        self.train_per_epoch = len(train_loader)
        self.val_per_epoch = len(val_loader) if val_loader else 0
        # Checkpoint / figure output locations under the logger's run directory.
        self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
        self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
        self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
        # Buffers for collecting NFE (number of function evaluations) statistics.
        self.c = []
        self.res, self.keys = [], []
    def _run_epoch(self, epoch, dataloader, mode):
        """Run one full pass over ``dataloader``.

        Args:
            epoch: 1-based epoch index (logging/bookkeeping only).
            dataloader: iterable yielding (data, target) batches.
            mode: 'train' (gradients + optimizer step) or 'val'/'test' (eval, no grads).

        Returns:
            float: the average loss over all batches.
        """
        if mode == 'train':
            self.model.train()
            optimizer_step = True
        else:
            self.model.eval()
            optimizer_step = False
        total_loss = 0
        epoch_time = time.time()
        # Reset the NFE statistics buffer at the start of each training epoch.
        if mode == 'train':
            self.c.clear()
        with torch.set_grad_enabled(optimizer_step):
            with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
                for batch_idx, (data, target) in enumerate(dataloader):
                    label = target[..., :self.args['output_dim']]
                    # Model returns (prediction, fe) where fe holds function-evaluation stats.
                    output, fe = self.model(data)
                    if self.args['real_value']:
                        # De-normalize only the output feature dimensions.
                        output = self._inverse_transform_output(output)
                    loss = self.loss(output, label)
                    # Collect NFE statistics (training mode only).
                    if mode == 'train':
                        self.c.append([*fe, loss.item()])
                        # Log function-evaluation info.
                        self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
                    if optimizer_step and self.optimizer is not None:
                        self.optimizer.zero_grad()
                        loss.backward()
                        if self.args['grad_norm']:
                            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
                        self.optimizer.step()
                    total_loss += loss.item()
                    # NOTE(review): other call sites use self.logger.logger.info — confirm
                    # the logger object also exposes .info directly.
                    if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
                        self.logger.info(
                            f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
                    # Advance the tqdm progress bar.
                    pbar.update(1)
                    pbar.set_postfix(loss=loss.item())
        avg_loss = total_loss / len(dataloader)
        self.logger.logger.info(
            f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
        # Persist this epoch's NFE statistics (training mode only).
        if mode == 'train':
            self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
            self.keys.append(epoch)
        return avg_loss
    def _inverse_transform_output(self, output):
        """
        De-normalize only the output feature dimensions.

        Assumes output shape is [batch, horizon, nodes, features] — TODO confirm —
        and applies scaler ``i`` to feature ``i`` for the first ``output_dim``
        features (in place on the tensor).
        """
        if not self.args['real_value']:
            return output
        # Number of output feature dimensions.
        output_dim = self.args['output_dim']
        # If there are at most as many output features as scalers, use the matching scaler per feature.
        if output_dim <= len(self.scalers):
            # De-normalize each output feature with its own scaler.
            for feature_idx in range(output_dim):
                if feature_idx < len(self.scalers):
                    output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
                        output[..., feature_idx:feature_idx+1]
                    )
        else:
            # More output features than scalers: de-normalize only the first len(scalers) features.
            for feature_idx in range(len(self.scalers)):
                output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
                    output[..., feature_idx:feature_idx+1]
                )
        return output
    def train_epoch(self, epoch):
        """Run one training epoch; returns the average training loss."""
        return self._run_epoch(epoch, self.train_loader, 'train')
    def val_epoch(self, epoch):
        """Run one validation epoch (falls back to the test loader when no val loader exists)."""
        return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
    def test_epoch(self, epoch):
        """Run one test epoch; returns the average test loss."""
        return self._run_epoch(epoch, self.test_loader, 'test')
    def train(self):
        """Full training loop with early stopping and dual checkpointing.

        Tracks two checkpoints: the best model by validation loss (``best_path``)
        and the best model by test loss (``best_test_path``); both are re-evaluated
        by ``_finalize_training`` afterwards.
        """
        best_model, best_test_model = None, None
        best_loss, best_test_loss = float('inf'), float('inf')
        not_improved_count = 0
        self.logger.logger.info("Training process started")
        for epoch in range(1, self.args['epochs'] + 1):
            train_epoch_loss = self.train_epoch(epoch)
            val_epoch_loss = self.val_epoch(epoch)
            test_epoch_loss = self.test_epoch(epoch)
            # Abort on a clearly diverged run.
            if train_epoch_loss > 1e6:
                self.logger.logger.warning('Gradient explosion detected. Ending...')
                break
            if val_epoch_loss < best_loss:
                best_loss = val_epoch_loss
                not_improved_count = 0
                best_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_model, self.best_path)
                self.logger.logger.info('Best validation model saved!')
            else:
                not_improved_count += 1
            if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
                self.logger.logger.info(
                    f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
                break
            if test_epoch_loss < best_test_loss:
                best_test_loss = test_epoch_loss
                best_test_model = copy.deepcopy(self.model.state_dict())
                torch.save(best_test_model, self.best_test_path)
            # Save NFE statistics if enabled.
            # NOTE(review): self.args is a dict, so hasattr(self.args, 'nfe') is always
            # False and this branch never runs; likely intended self.args.get('nfe', False).
            if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
                self._save_nfe_data()
        if not self.args['debug']:
            # NOTE(review): best_model/best_test_model may still be None if no epoch
            # improved before an early break — confirm torch.save(None, ...) is acceptable here.
            torch.save(best_model, self.best_path)
            torch.save(best_test_model, self.best_test_path)
            self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
        self._finalize_training(best_model, best_test_model)
    def _finalize_training(self, best_model, best_test_model):
        """Evaluate both tracked checkpoints on the test set after training completes.

        Only the best-test-model run produces visualization figures.
        """
        self.model.load_state_dict(best_model)
        self.logger.logger.info("Testing on best validation model")
        self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
        self.model.load_state_dict(best_test_model)
        self.logger.logger.info("Testing on best test model")
        self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
    @staticmethod
    def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
        """Evaluate ``model`` on ``data_loader`` and log per-horizon and average metrics.

        Args:
            model: the model to evaluate; optionally restored from ``path`` first.
            args: the 'train' config section (dict) — device, output_dim, thresholds.
            data_loader: iterable yielding (data, target) batches.
            scalers: list of per-feature normalizers for inverse transforms.
            logger: project logger exposing ``all_metrics`` and ``.logger``.
            path: optional checkpoint path containing a 'state_dict' entry.
            generate_viz: when True, also write node/prediction figures.
        """
        if path:
            checkpoint = torch.load(path)
            model.load_state_dict(checkpoint['state_dict'])
        model.to(args['device'])
        model.eval()
        y_pred, y_true = [], []
        # Buffer for collecting NFE statistics.
        c = []
        with torch.no_grad():
            for data, target in data_loader:
                label = target[..., :args['output_dim']]
                output, fe = model(data)
                y_pred.append(output)
                y_true.append(label)
                # Collect NFE statistics.
                c.append([*fe, 0.0])  # no loss at test time, store 0
        if args['real_value']:
            # De-normalize only the output feature dimensions.
            y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
        else:
            y_pred = torch.cat(y_pred, dim=0)
        y_true = torch.cat(y_true, dim=0)
        # Metrics per prediction horizon step.
        for t in range(y_true.shape[1]):
            mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
                                                 args['mae_thresh'], args['mape_thresh'])
            logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
        logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
        # Save NFE statistics if enabled.
        # NOTE(review): args is a dict, so hasattr(args, 'nfe') is always False and this
        # branch never runs; likely intended args.get('nfe', False).
        if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
            Trainer._save_nfe_data_static(c, model, logger)
        # Generate visualization figures only when requested.
        if generate_viz:
            save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
            Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
            Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
                                                      target_node=1, num_samples=10, scalers=scalers)
    @staticmethod
    def _inverse_transform_output_static(output, args, scalers):
        """
        Static variant: de-normalize only the output feature dimensions.

        Applies scaler ``i`` to feature ``i`` for the first ``output_dim``
        features (in place on the tensor); extra features beyond the number
        of scalers are left normalized.
        """
        if not args['real_value']:
            return output
        # Number of output feature dimensions.
        output_dim = args['output_dim']
        # If there are at most as many output features as scalers, use the matching scaler per feature.
        if output_dim <= len(scalers):
            # De-normalize each output feature with its own scaler.
            for feature_idx in range(output_dim):
                if feature_idx < len(scalers):
                    output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
                        output[..., feature_idx:feature_idx+1]
                    )
        else:
            # More output features than scalers: de-normalize only the first len(scalers) features.
            for feature_idx in range(len(scalers)):
                output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
                    output[..., feature_idx:feature_idx+1]
                )
        return output
    @staticmethod
    def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
        """
        Generate per-node prediction visualization figures.

        Writes up to 10 per-node "first 500 steps" figures plus one 2x5 summary
        figure (first 100 steps) into ``<save_dir>/pic``.

        Args:
            y_pred: predicted values tensor.
            y_true: ground-truth values tensor.
            logger: project logger (unused here beyond signature symmetry).
            save_dir: directory under which the 'pic' folder is created.
        """
        import matplotlib.pyplot as plt
        import numpy as np
        import os
        import matplotlib
        from tqdm import tqdm
        # Configure matplotlib to reduce font-lookup output.
        matplotlib.set_loglevel('error')  # show errors only
        plt.rcParams['font.family'] = 'DejaVu Sans'  # use the default font
        # Validate input data.
        if y_pred is None or y_true is None:
            return
        # Create the pic folder.
        pic_dir = os.path.join(save_dir, 'pic')
        os.makedirs(pic_dir, exist_ok=True)
        # Always generate (up to) 10 figures.
        num_nodes_to_plot = 10
        # Detailed per-node figures.
        with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
            for node_id in range(num_nodes_to_plot):
                # Extract this node's series.
                if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
                    # Data layout: [time_step, seq_len, num_node, dim].
                    # NOTE(review): original comment said "t=1" but the code indexes
                    # horizon step 12 — confirm the intended horizon.
                    node_pred = y_pred[:, 12, node_id, 0].cpu().numpy()  # chosen horizon, node, first feature
                    node_true = y_true[:, 12, node_id, 0].cpu().numpy()
                else:
                    # Fewer than 10 nodes available: only plot nodes that exist.
                    if node_id >= y_pred.shape[-2]:
                        pbar.update(1)
                        continue
                    else:
                        node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
                        node_true = y_true[:, 0, node_id, 0].cpu().numpy()
                # Validate the series.
                if np.isnan(node_pred).any() or np.isnan(node_true).any():
                    pbar.update(1)
                    continue
                # Take the first 500 time steps.
                max_steps = min(500, len(node_pred))
                if max_steps <= 0:
                    pbar.update(1)
                    continue
                node_pred_500 = node_pred[:max_steps]
                node_true_500 = node_true[:max_steps]
                # Build the time axis.
                time_steps = np.arange(max_steps)
                # Plot the true-vs-predicted comparison.
                plt.figure(figsize=(12, 6))
                plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
                plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
                plt.xlabel('Time Steps')
                plt.ylabel('Values')
                plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
                plt.legend()
                plt.grid(True, alpha=0.3)
                # Save the figure with a per-node name.
                save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
                plt.savefig(save_path, dpi=300, bbox_inches='tight')
                plt.close()
                pbar.update(1)
        # Summary figure of all nodes over the first 100 time steps (easier to inspect).
        plot_steps = min(100, y_pred.shape[0])
        if plot_steps <= 0:
            return
        # Create the subplot grid.
        fig, axes = plt.subplots(2, 5, figsize=(20, 8))
        axes = axes.flatten()
        for node_id in range(num_nodes_to_plot):
            if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
                # Data layout: [time_step, seq_len, num_node, dim].
                node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
                node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
            else:
                # Fewer than 10 nodes available: only plot nodes that exist.
                if node_id >= y_pred.shape[-2]:
                    axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                                       ha='center', va='center', transform=axes[node_id].transAxes)
                    continue
                else:
                    node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
                    node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
            # Validate the series.
            if np.isnan(node_pred).any() or np.isnan(node_true).any():
                axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
                                   ha='center', va='center', transform=axes[node_id].transAxes)
                continue
            time_steps = np.arange(plot_steps)
            axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
            axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
            axes[node_id].set_title(f'Node {node_id + 1}')
            axes[node_id].grid(True, alpha=0.3)
            axes[node_id].legend(fontsize=8)
            if node_id >= 5:  # bottom row gets the x-axis label
                axes[node_id].set_xlabel('Time Steps')
            if node_id % 5 == 0:  # left column gets the y-axis label
                axes[node_id].set_ylabel('Values')
        plt.tight_layout()
        summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
        plt.savefig(summary_path, dpi=300, bbox_inches='tight')
        plt.close()
@staticmethod
def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
                                      target_node=1, num_samples=10, scalers=None):
    """
    Generate per-sample input-vs-output comparison plots for one node.

    For each of the first ``num_samples`` batches, plots the (denormalized)
    input sequence followed by the true output, overlays the prediction for
    the output span, and saves one PNG per sample plus a 2x5 summary grid.

    Args:
        y_pred: predicted values; indexed as [sample, horizon, node, feature]
            below — assumes a 4-D tensor, TODO confirm against caller.
        y_true: ground-truth values, same layout as ``y_pred``.
        data_loader: data loader used to recover the input sequences.
        logger: log handle (currently unused in this method).
        save_dir: root directory; plots go under ``<save_dir>/pic/compare``.
        target_node: 1-based node ID to plot.
        num_samples: number of samples to compare.
        scalers: optional list of scalers; the first one is used to
            denormalize the input sequence.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import os
    import matplotlib
    from tqdm import tqdm
    # Matplotlib configuration: silence chatty log output, fixed font.
    matplotlib.set_loglevel('error')
    plt.rcParams['font.family'] = 'DejaVu Sans'
    # Create the "compare" output folder.
    compare_dir = os.path.join(save_dir, 'pic', 'compare')
    os.makedirs(compare_dir, exist_ok=True)
    # Collect up to num_samples input batches from the loader.
    input_data = []
    for batch_idx, (data, target) in enumerate(data_loader):
        if batch_idx >= num_samples:
            break
        input_data.append(data.cpu().numpy())
    if not input_data:
        return
    # Convert 1-based node ID to a 0-based index.
    node_idx = target_node - 1
    # Bail out if the requested node does not exist in the prediction tensor.
    if node_idx >= y_pred.shape[-2]:
        return
    # One comparison figure per sample.
    with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
        for sample_idx in range(min(num_samples, len(input_data))):
            # Input sequence: first batch element, all time steps, target
            # node, first feature (assumes [batch, seq_len, nodes, features]
            # — TODO confirm).
            input_seq = input_data[sample_idx][0, :, node_idx, 0]
            # Denormalize the input with the first scaler, if provided
            # (presumably the input feature uses scaler 0 — verify).
            if scalers is not None and len(scalers) > 0:
                input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
            # Matching prediction and ground truth over the full horizon.
            pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
            true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
            # Skip samples containing NaNs.
            if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
                pbar.update(1)
                continue
            # Continuous time axis covering input followed by output.
            total_time = np.arange(len(input_seq) + len(pred_seq))
            # Single figure showing input and output together.
            plt.figure(figsize=(14, 8))
            # Full true curve: input concatenated with the true output.
            true_combined = np.concatenate([input_seq, true_seq])
            plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
                     linewidth=2.5, alpha=0.9, linestyle='-')
            # Prediction curve drawn over the output span only.
            output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
            plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
                     linewidth=2, alpha=0.8, linestyle='-')
            # Vertical line marking the input/output boundary.
            plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
                        label='Input/Output Boundary')
            # Figure cosmetics.
            plt.xlabel('Time Steps')
            plt.ylabel('Values')
            plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            # Save and close to avoid accumulating open figures.
            save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            plt.close()
            pbar.update(1)
    # Summary figure: all samples side by side in a 2x5 grid.
    fig, axes = plt.subplots(2, 5, figsize=(20, 8))
    axes = axes.flatten()
    for sample_idx in range(min(num_samples, len(input_data))):
        if sample_idx >= 10:  # at most 10 subplots
            break
        ax = axes[sample_idx]
        # Same extraction/denormalization as the per-sample loop above.
        input_seq = input_data[sample_idx][0, :, node_idx, 0]
        if scalers is not None and len(scalers) > 0:
            input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
        pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
        true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
        # NaN samples get a placeholder label instead of a plot.
        if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
            ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
                    ha='center', va='center', transform=ax.transAxes)
            continue
        # Continuous input+output comparison, mirroring the large figures.
        total_time = np.arange(len(input_seq) + len(pred_seq))
        true_combined = np.concatenate([input_seq, true_seq])
        output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
        ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
        ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
        ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
        ax.set_title(f'Sample {sample_idx + 1}')
        ax.grid(True, alpha=0.3)
        ax.legend(fontsize=8)
        if sample_idx >= 5:  # x-label only on the bottom row
            ax.set_xlabel('Time Steps')
        if sample_idx % 5 == 0:  # y-label only on the left column
            ax.set_ylabel('Values')
    # Hide any unused subplots.
    for i in range(min(num_samples, len(input_data)), 10):
        axes[i].set_visible(False)
    plt.tight_layout()
    summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
    plt.savefig(summary_path, dpi=300, bbox_inches='tight')
    plt.close()
def _save_nfe_data(self):
"""保存nfe数据到文件"""
if not self.res:
return
res = pd.concat(self.res, keys=self.keys)
res.index.names = ['epoch', 'iter']
# 获取模型配置参数
filter_type = getattr(self.model, 'filter_type', 'unknown')
atol = getattr(self.model, 'atol', 1e-5)
rtol = getattr(self.model, 'rtol', 1e-5)
# 保存nfe数据
nfe_file = os.path.join(
self.logger.dir_path,
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
res.to_pickle(nfe_file)
self.logger.logger.info(f"NFE data saved to {nfe_file}")
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
# ---- diff-viewer boundary: the function below belongs to a separate file
# (trainer/select_trainer.py) shown in the same compare view ----
from trainer.trainer import Trainer
from trainer.ode_trainer import Trainer as ode_trainer
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
lr_scheduler, kwargs):
model_name = config['basic']['model']
selected_Trainer = None
match model_name:
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler, lr_scheduler)
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler,lr_scheduler)
train_loader, val_loader, test_loader, scaler, lr_scheduler)
if selected_Trainer is None: raise NotImplementedError
return selected_Trainer