更新项目文件:添加数据选择器、模型选择器、训练器选择器和ODE训练器

This commit is contained in:
czzhangheng 2025-09-03 07:52:18 +08:00
parent df8c573f4c
commit 66a23ffbbb
24 changed files with 825 additions and 7 deletions

1
.gitignore vendored
View File

@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder. # option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/ #.idea/
STDEN/

8
.idea/.gitignore vendored Normal file
View File

@ -0,0 +1,8 @@
# 默认忽略的文件
/shelf/
/workspace.xml
# 基于编辑器的 HTTP 客户端请求
/httpRequests/
# Datasource local storage ignored files
/dataSources/
/dataSources.local.xml

12
.idea/Project-I.iml Normal file
View File

@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>

View File

@ -0,0 +1,27 @@
<component name="InspectionProjectProfileManager">
<profile version="1.0">
<option name="myName" value="Project Default" />
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
<Languages>
<language minSize="136" name="Python" />
</Languages>
</inspection_tool>
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
<option name="ignoredPackages">
<value>
<list size="8">
<item index="0" class="java.lang.String" itemvalue="argparse" />
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
<item index="5" class="java.lang.String" itemvalue="setuptools" />
<item index="6" class="java.lang.String" itemvalue="numpy" />
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
</list>
</value>
</option>
</inspection_tool>
</profile>
</component>

View File

@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>

7
.idea/misc.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.10" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
</project>

8
.idea/modules.xml Normal file
View File

@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
</modules>
</component>
</project>

7
.idea/vcs.xml Normal file
View File

@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="" vcs="Git" />
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
</component>
</project>

296
data/PEMS08/PEMS08.csv Executable file
View File

@ -0,0 +1,296 @@
from,to,cost
9,153,310.6
153,62,330.9
62,111,332.9
111,11,324.2
11,28,336.0
28,169,133.7
138,135,354.7
135,133,387.9
133,163,337.1
163,20,352.0
20,19,420.8
19,14,351.3
14,39,340.2
39,164,350.3
164,167,365.2
167,70,359.0
70,59,388.2
59,58,305.7
58,67,294.4
67,66,299.5
66,55,313.3
55,53,332.1
53,150,278.9
150,61,308.4
61,64,311.4
64,63,243.6
47,65,372.8
65,48,319.4
48,49,309.7
49,54,320.5
54,56,318.3
56,57,297.9
57,68,293.5
68,69,342.5
69,60,318.0
60,17,305.9
17,5,321.4
5,18,402.2
18,22,447.4
22,30,377.5
30,29,417.7
29,21,360.8
21,132,407.6
132,134,386.9
134,136,350.2
123,121,326.3
121,140,385.2
140,118,393.0
118,96,296.7
96,94,398.2
94,86,337.1
86,78,473.8
78,46,353.4
46,152,385.7
152,157,350.0
157,35,354.4
35,77,356.1
77,52,354.2
52,3,357.8
3,16,382.4
16,0,55.7
42,12,335.1
12,139,328.8
139,168,412.6
168,154,337.3
154,143,370.7
143,10,6.3
107,105,354.6
105,104,386.9
104,148,362.1
148,97,316.3
97,101,380.7
101,137,361.4
137,102,365.5
102,24,375.5
24,166,312.2
129,156,256.1
156,33,329.1
33,32,356.5
91,89,405.6
89,147,347.0
147,15,351.7
15,44,339.5
44,41,350.8
41,43,322.6
43,100,338.9
100,83,347.9
83,87,327.2
87,88,321.0
88,75,335.8
75,51,384.8
51,73,391.1
73,71,289.3
31,155,260.0
155,34,320.4
34,128,393.3
145,115,399.4
115,112,328.1
112,8,469.4
8,117,816.2
117,125,397.1
125,127,372.7
127,109,380.5
109,161,355.5
161,110,367.7
110,160,102.0
72,159,342.9
159,50,383.3
50,74,354.1
74,82,350.2
82,81,335.4
81,99,391.6
99,84,354.9
84,13,306.4
13,40,327.4
40,162,413.9
162,108,301.9
108,146,317.8
146,85,376.6
85,90,347.0
26,27,341.6
27,6,359.4
6,149,417.8
149,126,388.0
126,124,384.3
124,7,763.3
7,114,323.1
114,113,351.6
113,116,411.9
116,144,262.0
25,103,350.2
103,23,376.3
23,165,396.4
165,38,381.0
38,92,368.0
92,37,336.3
37,130,357.8
130,106,532.3
106,131,166.5
1,2,371.6
2,4,338.1
4,76,429.0
76,36,366.1
36,158,344.5
158,151,350.1
151,45,358.8
45,93,340.9
93,80,329.9
80,79,384.1
79,95,335.7
95,98,320.9
98,119,340.3
119,120,376.8
120,122,393.1
122,141,428.7
141,142,359.3
30,165,379.6
165,29,41.7
29,38,343.3
65,72,297.9
72,48,21.5
17,153,375.6
153,5,256.3
153,62,330.9
18,6,499.4
6,22,254.0
22,149,185.4
22,4,257.9
4,30,236.8
30,76,307.0
95,98,320.9
98,144,45.1
45,93,340.9
93,106,112.2
162,151,113.6
151,108,192.9
108,45,359.8
146,92,311.2
92,85,343.9
85,37,373.2
13,169,326.2
169,40,96.1
124,13,460.7
13,7,305.5
7,40,624.1
124,169,145.2
169,7,631.5
90,132,152.2
26,32,106.7
9,129,148.3
129,153,219.6
31,26,116.0
26,155,270.7
9,128,142.2
128,153,215.0
153,167,269.7
167,62,64.8
62,70,332.6
124,169,145.2
169,7,631.5
44,169,397.8
169,41,124.0
44,124,375.7
124,41,243.9
41,7,519.4
6,14,289.3
14,149,259.0
149,39,206.9
144,98,45.1
19,4,326.8
4,14,178.6
14,76,299.0
15,151,136.4
151,44,203.1
45,106,260.6
106,93,112.2
20,165,132.5
165,19,289.2
89,92,323.2
92,147,321.9
147,37,48.2
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
78,107,143.9
107,46,236.3
104,147,277.5
147,148,84.7
20,101,201.2
101,19,534.4
19,137,245.5
8,42,759.5
42,117,58.9
44,42,342.3
42,41,102.5
44,8,789.1
8,41,657.4
41,117,160.5
168,167,172.4
167,154,165.2
143,128,81.9
128,10,88.2
118,145,250.6
145,96,85.1
15,152,135.0
152,44,204.6
19,77,320.7
77,14,299.8
14,52,127.6
14,127,314.8
127,39,280.4
39,109,237.0
31,160,116.5
160,155,272.4
133,91,152.8
91,163,313.6
150,71,221.1
71,61,89.6
32,160,107.7
72,162,3274.4
162,13,554.5
162,40,413.9
65,72,297.9
72,48,21.5
13,42,319.8
42,40,40.7
8,42,759.5
42,117,58.9
8,13,450.3
13,117,378.5
117,40,64.0
46,162,391.6
162,152,115.3
152,108,191.4
104,108,375.9
108,148,311.6
148,146,80.0
21,90,396.9
90,132,152.2
101,29,252.3
29,137,110.7
77,22,353.8
22,52,227.8
52,30,186.6
127,18,425.2
18,109,439.1
109,22,135.5
168,17,232.7
17,154,294.2
154,5,166.3
78,107,143.9
107,46,236.3
118,145,250.6
145,96,85.1
1 from to cost
2 9 153 310.6
3 153 62 330.9
4 62 111 332.9
5 111 11 324.2
6 11 28 336.0
7 28 169 133.7
8 138 135 354.7
9 135 133 387.9
10 133 163 337.1
11 163 20 352.0
12 20 19 420.8
13 19 14 351.3
14 14 39 340.2
15 39 164 350.3
16 164 167 365.2
17 167 70 359.0
18 70 59 388.2
19 59 58 305.7
20 58 67 294.4
21 67 66 299.5
22 66 55 313.3
23 55 53 332.1
24 53 150 278.9
25 150 61 308.4
26 61 64 311.4
27 64 63 243.6
28 47 65 372.8
29 65 48 319.4
30 48 49 309.7
31 49 54 320.5
32 54 56 318.3
33 56 57 297.9
34 57 68 293.5
35 68 69 342.5
36 69 60 318.0
37 60 17 305.9
38 17 5 321.4
39 5 18 402.2
40 18 22 447.4
41 22 30 377.5
42 30 29 417.7
43 29 21 360.8
44 21 132 407.6
45 132 134 386.9
46 134 136 350.2
47 123 121 326.3
48 121 140 385.2
49 140 118 393.0
50 118 96 296.7
51 96 94 398.2
52 94 86 337.1
53 86 78 473.8
54 78 46 353.4
55 46 152 385.7
56 152 157 350.0
57 157 35 354.4
58 35 77 356.1
59 77 52 354.2
60 52 3 357.8
61 3 16 382.4
62 16 0 55.7
63 42 12 335.1
64 12 139 328.8
65 139 168 412.6
66 168 154 337.3
67 154 143 370.7
68 143 10 6.3
69 107 105 354.6
70 105 104 386.9
71 104 148 362.1
72 148 97 316.3
73 97 101 380.7
74 101 137 361.4
75 137 102 365.5
76 102 24 375.5
77 24 166 312.2
78 129 156 256.1
79 156 33 329.1
80 33 32 356.5
81 91 89 405.6
82 89 147 347.0
83 147 15 351.7
84 15 44 339.5
85 44 41 350.8
86 41 43 322.6
87 43 100 338.9
88 100 83 347.9
89 83 87 327.2
90 87 88 321.0
91 88 75 335.8
92 75 51 384.8
93 51 73 391.1
94 73 71 289.3
95 31 155 260.0
96 155 34 320.4
97 34 128 393.3
98 145 115 399.4
99 115 112 328.1
100 112 8 469.4
101 8 117 816.2
102 117 125 397.1
103 125 127 372.7
104 127 109 380.5
105 109 161 355.5
106 161 110 367.7
107 110 160 102.0
108 72 159 342.9
109 159 50 383.3
110 50 74 354.1
111 74 82 350.2
112 82 81 335.4
113 81 99 391.6
114 99 84 354.9
115 84 13 306.4
116 13 40 327.4
117 40 162 413.9
118 162 108 301.9
119 108 146 317.8
120 146 85 376.6
121 85 90 347.0
122 26 27 341.6
123 27 6 359.4
124 6 149 417.8
125 149 126 388.0
126 126 124 384.3
127 124 7 763.3
128 7 114 323.1
129 114 113 351.6
130 113 116 411.9
131 116 144 262.0
132 25 103 350.2
133 103 23 376.3
134 23 165 396.4
135 165 38 381.0
136 38 92 368.0
137 92 37 336.3
138 37 130 357.8
139 130 106 532.3
140 106 131 166.5
141 1 2 371.6
142 2 4 338.1
143 4 76 429.0
144 76 36 366.1
145 36 158 344.5
146 158 151 350.1
147 151 45 358.8
148 45 93 340.9
149 93 80 329.9
150 80 79 384.1
151 79 95 335.7
152 95 98 320.9
153 98 119 340.3
154 119 120 376.8
155 120 122 393.1
156 122 141 428.7
157 141 142 359.3
158 30 165 379.6
159 165 29 41.7
160 29 38 343.3
161 65 72 297.9
162 72 48 21.5
163 17 153 375.6
164 153 5 256.3
165 153 62 330.9
166 18 6 499.4
167 6 22 254.0
168 22 149 185.4
169 22 4 257.9
170 4 30 236.8
171 30 76 307.0
172 95 98 320.9
173 98 144 45.1
174 45 93 340.9
175 93 106 112.2
176 162 151 113.6
177 151 108 192.9
178 108 45 359.8
179 146 92 311.2
180 92 85 343.9
181 85 37 373.2
182 13 169 326.2
183 169 40 96.1
184 124 13 460.7
185 13 7 305.5
186 7 40 624.1
187 124 169 145.2
188 169 7 631.5
189 90 132 152.2
190 26 32 106.7
191 9 129 148.3
192 129 153 219.6
193 31 26 116.0
194 26 155 270.7
195 9 128 142.2
196 128 153 215.0
197 153 167 269.7
198 167 62 64.8
199 62 70 332.6
200 124 169 145.2
201 169 7 631.5
202 44 169 397.8
203 169 41 124.0
204 44 124 375.7
205 124 41 243.9
206 41 7 519.4
207 6 14 289.3
208 14 149 259.0
209 149 39 206.9
210 144 98 45.1
211 19 4 326.8
212 4 14 178.6
213 14 76 299.0
214 15 151 136.4
215 151 44 203.1
216 45 106 260.6
217 106 93 112.2
218 20 165 132.5
219 165 19 289.2
220 89 92 323.2
221 92 147 321.9
222 147 37 48.2
223 133 91 152.8
224 91 163 313.6
225 150 71 221.1
226 71 61 89.6
227 78 107 143.9
228 107 46 236.3
229 104 147 277.5
230 147 148 84.7
231 20 101 201.2
232 101 19 534.4
233 19 137 245.5
234 8 42 759.5
235 42 117 58.9
236 44 42 342.3
237 42 41 102.5
238 44 8 789.1
239 8 41 657.4
240 41 117 160.5
241 168 167 172.4
242 167 154 165.2
243 143 128 81.9
244 128 10 88.2
245 118 145 250.6
246 145 96 85.1
247 15 152 135.0
248 152 44 204.6
249 19 77 320.7
250 77 14 299.8
251 14 52 127.6
252 14 127 314.8
253 127 39 280.4
254 39 109 237.0
255 31 160 116.5
256 160 155 272.4
257 133 91 152.8
258 91 163 313.6
259 150 71 221.1
260 71 61 89.6
261 32 160 107.7
262 72 162 3274.4
263 162 13 554.5
264 162 40 413.9
265 65 72 297.9
266 72 48 21.5
267 13 42 319.8
268 42 40 40.7
269 8 42 759.5
270 42 117 58.9
271 8 13 450.3
272 13 117 378.5
273 117 40 64.0
274 46 162 391.6
275 162 152 115.3
276 152 108 191.4
277 104 108 375.9
278 108 148 311.6
279 148 146 80.0
280 21 90 396.9
281 90 132 152.2
282 101 29 252.3
283 29 137 110.7
284 77 22 353.8
285 22 52 227.8
286 52 30 186.6
287 127 18 425.2
288 18 109 439.1
289 109 22 135.5
290 168 17 232.7
291 17 154 294.2
292 154 5 166.3
293 78 107 143.9
294 107 46 236.3
295 118 145 250.6
296 145 96 85.1

BIN
data/PEMS08/PEMS08.npz Executable file

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -2,6 +2,7 @@ import numpy as np
import os import os
def load_dataset(config): def load_dataset(config):
dataset_name = config['basic']['dataset'] dataset_name = config['basic']['dataset']
node_num = config['data']['num_nodes'] node_num = config['data']['num_nodes']
input_dim = config['data']['input_dim'] input_dim = config['data']['input_dim']
@ -10,4 +11,8 @@ def load_dataset(config):
case 'EcoSolar': case 'EcoSolar':
data_path = os.path.join('./data/EcoSolar.npy') data_path = os.path.join('./data/EcoSolar.npy')
data = np.load(data_path)[:, :node_num, :input_dim] data = np.load(data_path)[:, :node_num, :input_dim]
case 'PEMS08':
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
data = np.load(data_path)['data'][:, :node_num, :input_dim]
return data return data

9
data/graph_loader.py Normal file
View File

@ -0,0 +1,9 @@
import numpy as np
def load_graph(config):
dataset_path = config['data']['graph_pkl_filename']
graph = np.load(dataset_path)
# 将inf值填充为0
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
return graph

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 100
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 100
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -0,0 +1,62 @@
basic:
dataset: PEMS08
device: cuda:0
mode: train
model: STDEN
seed: 2025
data:
add_day_in_week: true
add_time_in_day: true
batch_size: 32
column_wise: false
dataset_dir: data/PEMS08
days_per_week: 7
default_graph: true
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
horizon: 24
input_dim: 1
lag: 24
normalizer: std
num_nodes: 170
steps_per_day: 24
test_ratio: 0.2
tod: false
val_batch_size: 32
val_ratio: 0.2
model:
filter_type: default
gcn_step: 2
horizon: 12
input_dim: 1
l1_decay: 0
latent_dim: 4
n_traj_samples: 3
nfe: false
num_rnn_layers: 1
ode_method: dopri5
odeint_atol: 1.0e-05
odeint_rtol: 1.0e-05
output_dim: 1
recg_type: gru
rnn_units: 64
save_latent: false
seq_len: 12
train:
batch_size: 64
debug: false
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
loss: mae
lr_decay: false
lr_decay_rate: 0.3
lr_decay_step: 5,20,40,70
lr_init: 0.003
mae_thresh: None
mape_thresh: 0.001
max_grad_norm: 5
output_dim: 1
real_value: true
weight_decay: 0

View File

@ -3,8 +3,6 @@
时空数据深度学习预测项目主程序 时空数据深度学习预测项目主程序
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features) 专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
""" """
import os
from utils.args_reader import config_loader from utils.args_reader import config_loader
import utils.init as init import utils.init as init
import torch import torch

View File

@ -1,6 +1,8 @@
from models.STDEN.stden_model import STDENModel
def model_selector(config): def model_selector(config):
model_name = config['basic']['model'] model_name = config['basic']['model']
model = None model = None
match model_name:
case 'STDEN': model = STDENModel(config)
return model return model

View File

@ -2,6 +2,8 @@ import math
import os import os
import time import time
import copy import copy
import pandas as pd
import numpy as np
from tqdm import tqdm from tqdm import tqdm
import torch import torch
@ -25,6 +27,10 @@ class Trainer:
self.best_path = os.path.join(logger.dir_path, 'best_model.pth') self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth') self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png') self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
# 用于收集nfe数据
self.c = []
self.res, self.keys = [], []
def _run_epoch(self, epoch, dataloader, mode): def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train': if mode == 'train':
@ -36,18 +42,29 @@ class Trainer:
total_loss = 0 total_loss = 0
epoch_time = time.time() epoch_time = time.time()
# 清空nfe数据收集
if mode == 'train':
self.c.clear()
with torch.set_grad_enabled(optimizer_step): with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader): for batch_idx, (data, target) in enumerate(dataloader):
label = target[..., :self.args['output_dim']] label = target[..., :self.args['output_dim']]
output = self.model(data).to(self.args['device']) output, fe = self.model(data)
if self.args['real_value']: if self.args['real_value']:
# 只对输出维度进行反归一化 # 只对输出维度进行反归一化
output = self._inverse_transform_output(output) output = self._inverse_transform_output(output)
loss = self.loss(output, label) loss = self.loss(output, label)
# 收集nfe数据仅在训练模式下
if mode == 'train':
self.c.append([*fe, loss.item()])
# 记录FE信息
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
if optimizer_step and self.optimizer is not None: if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad() self.optimizer.zero_grad()
loss.backward() loss.backward()
@ -69,6 +86,12 @@ class Trainer:
avg_loss = total_loss / len(dataloader) avg_loss = total_loss / len(dataloader)
self.logger.logger.info( self.logger.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
# 收集nfe数据仅在训练模式下
if mode == 'train':
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
self.keys.append(epoch)
return avg_loss return avg_loss
def _inverse_transform_output(self, output): def _inverse_transform_output(self, output):
@ -143,6 +166,10 @@ class Trainer:
best_test_model = copy.deepcopy(self.model.state_dict()) best_test_model = copy.deepcopy(self.model.state_dict())
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
# 保存nfe数据如果启用
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
self._save_nfe_data()
if not self.args['debug']: if not self.args['debug']:
torch.save(best_model, self.best_path) torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path) torch.save(best_test_model, self.best_test_path)
@ -164,17 +191,23 @@ class Trainer:
if path: if path:
checkpoint = torch.load(path) checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict']) model.load_state_dict(checkpoint['state_dict'])
model.to(args.device) model.to(args['device'])
model.eval() model.eval()
y_pred, y_true = [], [] y_pred, y_true = [], []
# 用于收集nfe数据
c = []
with torch.no_grad(): with torch.no_grad():
for data, target in data_loader: for data, target in data_loader:
label = target[..., :args['output_dim']] label = target[..., :args['output_dim']]
output = model(data) output, fe = model(data)
y_pred.append(output) y_pred.append(output)
y_true.append(label) y_true.append(label)
# 收集nfe数据
c.append([*fe, 0.0]) # 测试时没有loss设为0
if args['real_value']: if args['real_value']:
# 只对输出维度进行反归一化 # 只对输出维度进行反归一化
@ -192,6 +225,10 @@ class Trainer:
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
# 保存nfe数据如果启用
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
Trainer._save_nfe_data_static(c, model, logger)
# 只在需要时生成可视化图片 # 只在需要时生成可视化图片
if generate_viz: if generate_viz:
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs' save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
@ -514,6 +551,26 @@ class Trainer:
plt.savefig(summary_path, dpi=300, bbox_inches='tight') plt.savefig(summary_path, dpi=300, bbox_inches='tight')
plt.close() plt.close()
def _save_nfe_data(self):
"""保存nfe数据到文件"""
if not self.res:
return
res = pd.concat(self.res, keys=self.keys)
res.index.names = ['epoch', 'iter']
# 获取模型配置参数
filter_type = getattr(self.model, 'filter_type', 'unknown')
atol = getattr(self.model, 'atol', 1e-5)
rtol = getattr(self.model, 'rtol', 1e-5)
# 保存nfe数据
nfe_file = os.path.join(
self.logger.dir_path,
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
res.to_pickle(nfe_file)
self.logger.logger.info(f"NFE data saved to {nfe_file}")
@staticmethod @staticmethod
def _compute_sampling_threshold(global_step, k): def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k)) return k / (k + math.exp(global_step / k))

View File

@ -1,11 +1,14 @@
from trainer.trainer import Trainer from trainer.trainer import Trainer
from trainer.ode_trainer import Trainer as ode_trainer
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
lr_scheduler, kwargs): lr_scheduler, kwargs):
model_name = config['basic']['model'] model_name = config['basic']['model']
selected_Trainer = None selected_Trainer = None
match model_name: match model_name:
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler, lr_scheduler)
case _: selected_Trainer = Trainer(config, model, loss, optimizer, case _: selected_Trainer = Trainer(config, model, loss, optimizer,
train_loader, val_loader, test_loader, scaler,lr_scheduler) train_loader, val_loader, test_loader, scaler, lr_scheduler)
if selected_Trainer is None: raise NotImplementedError if selected_Trainer is None: raise NotImplementedError
return selected_Trainer return selected_Trainer