Compare commits
2 Commits
19a02ba7ae
...
66a23ffbbb
| Author | SHA1 | Date |
|---|---|---|
|
|
66a23ffbbb | |
|
|
df8c573f4c |
|
|
@ -160,3 +160,4 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
STDEN/
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<Languages>
|
||||
<language minSize="136" name="Python" />
|
||||
</Languages>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="8">
|
||||
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
||||
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
||||
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
||||
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
||||
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
||||
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
||||
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
from,to,cost
|
||||
9,153,310.6
|
||||
153,62,330.9
|
||||
62,111,332.9
|
||||
111,11,324.2
|
||||
11,28,336.0
|
||||
28,169,133.7
|
||||
138,135,354.7
|
||||
135,133,387.9
|
||||
133,163,337.1
|
||||
163,20,352.0
|
||||
20,19,420.8
|
||||
19,14,351.3
|
||||
14,39,340.2
|
||||
39,164,350.3
|
||||
164,167,365.2
|
||||
167,70,359.0
|
||||
70,59,388.2
|
||||
59,58,305.7
|
||||
58,67,294.4
|
||||
67,66,299.5
|
||||
66,55,313.3
|
||||
55,53,332.1
|
||||
53,150,278.9
|
||||
150,61,308.4
|
||||
61,64,311.4
|
||||
64,63,243.6
|
||||
47,65,372.8
|
||||
65,48,319.4
|
||||
48,49,309.7
|
||||
49,54,320.5
|
||||
54,56,318.3
|
||||
56,57,297.9
|
||||
57,68,293.5
|
||||
68,69,342.5
|
||||
69,60,318.0
|
||||
60,17,305.9
|
||||
17,5,321.4
|
||||
5,18,402.2
|
||||
18,22,447.4
|
||||
22,30,377.5
|
||||
30,29,417.7
|
||||
29,21,360.8
|
||||
21,132,407.6
|
||||
132,134,386.9
|
||||
134,136,350.2
|
||||
123,121,326.3
|
||||
121,140,385.2
|
||||
140,118,393.0
|
||||
118,96,296.7
|
||||
96,94,398.2
|
||||
94,86,337.1
|
||||
86,78,473.8
|
||||
78,46,353.4
|
||||
46,152,385.7
|
||||
152,157,350.0
|
||||
157,35,354.4
|
||||
35,77,356.1
|
||||
77,52,354.2
|
||||
52,3,357.8
|
||||
3,16,382.4
|
||||
16,0,55.7
|
||||
42,12,335.1
|
||||
12,139,328.8
|
||||
139,168,412.6
|
||||
168,154,337.3
|
||||
154,143,370.7
|
||||
143,10,6.3
|
||||
107,105,354.6
|
||||
105,104,386.9
|
||||
104,148,362.1
|
||||
148,97,316.3
|
||||
97,101,380.7
|
||||
101,137,361.4
|
||||
137,102,365.5
|
||||
102,24,375.5
|
||||
24,166,312.2
|
||||
129,156,256.1
|
||||
156,33,329.1
|
||||
33,32,356.5
|
||||
91,89,405.6
|
||||
89,147,347.0
|
||||
147,15,351.7
|
||||
15,44,339.5
|
||||
44,41,350.8
|
||||
41,43,322.6
|
||||
43,100,338.9
|
||||
100,83,347.9
|
||||
83,87,327.2
|
||||
87,88,321.0
|
||||
88,75,335.8
|
||||
75,51,384.8
|
||||
51,73,391.1
|
||||
73,71,289.3
|
||||
31,155,260.0
|
||||
155,34,320.4
|
||||
34,128,393.3
|
||||
145,115,399.4
|
||||
115,112,328.1
|
||||
112,8,469.4
|
||||
8,117,816.2
|
||||
117,125,397.1
|
||||
125,127,372.7
|
||||
127,109,380.5
|
||||
109,161,355.5
|
||||
161,110,367.7
|
||||
110,160,102.0
|
||||
72,159,342.9
|
||||
159,50,383.3
|
||||
50,74,354.1
|
||||
74,82,350.2
|
||||
82,81,335.4
|
||||
81,99,391.6
|
||||
99,84,354.9
|
||||
84,13,306.4
|
||||
13,40,327.4
|
||||
40,162,413.9
|
||||
162,108,301.9
|
||||
108,146,317.8
|
||||
146,85,376.6
|
||||
85,90,347.0
|
||||
26,27,341.6
|
||||
27,6,359.4
|
||||
6,149,417.8
|
||||
149,126,388.0
|
||||
126,124,384.3
|
||||
124,7,763.3
|
||||
7,114,323.1
|
||||
114,113,351.6
|
||||
113,116,411.9
|
||||
116,144,262.0
|
||||
25,103,350.2
|
||||
103,23,376.3
|
||||
23,165,396.4
|
||||
165,38,381.0
|
||||
38,92,368.0
|
||||
92,37,336.3
|
||||
37,130,357.8
|
||||
130,106,532.3
|
||||
106,131,166.5
|
||||
1,2,371.6
|
||||
2,4,338.1
|
||||
4,76,429.0
|
||||
76,36,366.1
|
||||
36,158,344.5
|
||||
158,151,350.1
|
||||
151,45,358.8
|
||||
45,93,340.9
|
||||
93,80,329.9
|
||||
80,79,384.1
|
||||
79,95,335.7
|
||||
95,98,320.9
|
||||
98,119,340.3
|
||||
119,120,376.8
|
||||
120,122,393.1
|
||||
122,141,428.7
|
||||
141,142,359.3
|
||||
30,165,379.6
|
||||
165,29,41.7
|
||||
29,38,343.3
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
17,153,375.6
|
||||
153,5,256.3
|
||||
153,62,330.9
|
||||
18,6,499.4
|
||||
6,22,254.0
|
||||
22,149,185.4
|
||||
22,4,257.9
|
||||
4,30,236.8
|
||||
30,76,307.0
|
||||
95,98,320.9
|
||||
98,144,45.1
|
||||
45,93,340.9
|
||||
93,106,112.2
|
||||
162,151,113.6
|
||||
151,108,192.9
|
||||
108,45,359.8
|
||||
146,92,311.2
|
||||
92,85,343.9
|
||||
85,37,373.2
|
||||
13,169,326.2
|
||||
169,40,96.1
|
||||
124,13,460.7
|
||||
13,7,305.5
|
||||
7,40,624.1
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
90,132,152.2
|
||||
26,32,106.7
|
||||
9,129,148.3
|
||||
129,153,219.6
|
||||
31,26,116.0
|
||||
26,155,270.7
|
||||
9,128,142.2
|
||||
128,153,215.0
|
||||
153,167,269.7
|
||||
167,62,64.8
|
||||
62,70,332.6
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
44,169,397.8
|
||||
169,41,124.0
|
||||
44,124,375.7
|
||||
124,41,243.9
|
||||
41,7,519.4
|
||||
6,14,289.3
|
||||
14,149,259.0
|
||||
149,39,206.9
|
||||
144,98,45.1
|
||||
19,4,326.8
|
||||
4,14,178.6
|
||||
14,76,299.0
|
||||
15,151,136.4
|
||||
151,44,203.1
|
||||
45,106,260.6
|
||||
106,93,112.2
|
||||
20,165,132.5
|
||||
165,19,289.2
|
||||
89,92,323.2
|
||||
92,147,321.9
|
||||
147,37,48.2
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
104,147,277.5
|
||||
147,148,84.7
|
||||
20,101,201.2
|
||||
101,19,534.4
|
||||
19,137,245.5
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
44,42,342.3
|
||||
42,41,102.5
|
||||
44,8,789.1
|
||||
8,41,657.4
|
||||
41,117,160.5
|
||||
168,167,172.4
|
||||
167,154,165.2
|
||||
143,128,81.9
|
||||
128,10,88.2
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
15,152,135.0
|
||||
152,44,204.6
|
||||
19,77,320.7
|
||||
77,14,299.8
|
||||
14,52,127.6
|
||||
14,127,314.8
|
||||
127,39,280.4
|
||||
39,109,237.0
|
||||
31,160,116.5
|
||||
160,155,272.4
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
32,160,107.7
|
||||
72,162,3274.4
|
||||
162,13,554.5
|
||||
162,40,413.9
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
13,42,319.8
|
||||
42,40,40.7
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
8,13,450.3
|
||||
13,117,378.5
|
||||
117,40,64.0
|
||||
46,162,391.6
|
||||
162,152,115.3
|
||||
152,108,191.4
|
||||
104,108,375.9
|
||||
108,148,311.6
|
||||
148,146,80.0
|
||||
21,90,396.9
|
||||
90,132,152.2
|
||||
101,29,252.3
|
||||
29,137,110.7
|
||||
77,22,353.8
|
||||
22,52,227.8
|
||||
52,30,186.6
|
||||
127,18,425.2
|
||||
18,109,439.1
|
||||
109,22,135.5
|
||||
168,17,232.7
|
||||
17,154,294.2
|
||||
154,5,166.3
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import os
|
||||
|
||||
def load_dataset(config):
|
||||
|
||||
dataset_name = config['basic']['dataset']
|
||||
node_num = config['data']['num_nodes']
|
||||
input_dim = config['data']['input_dim']
|
||||
|
|
@ -10,4 +11,8 @@ def load_dataset(config):
|
|||
case 'EcoSolar':
|
||||
data_path = os.path.join('./data/EcoSolar.npy')
|
||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||
case 'PEMS08':
|
||||
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
||||
|
||||
return data
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import numpy as np
|
||||
|
||||
def load_graph(config):
|
||||
dataset_path = config['data']['graph_pkl_filename']
|
||||
graph = np.load(dataset_path)
|
||||
# 将inf值填充为0
|
||||
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
return graph
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 100
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 100
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
2
main.py
2
main.py
|
|
@ -3,8 +3,6 @@
|
|||
时空数据深度学习预测项目主程序
|
||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||
"""
|
||||
|
||||
import os
|
||||
from utils.args_reader import config_loader
|
||||
import utils.init as init
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
|
||||
from models.STDEN.stden_model import STDENModel
|
||||
|
||||
def model_selector(config):
|
||||
model_name = config['basic']['model']
|
||||
model = None
|
||||
match model_name:
|
||||
case 'STDEN': model = STDENModel(config)
|
||||
return model
|
||||
|
|
@ -0,0 +1,576 @@
|
|||
import math
|
||||
import os
|
||||
import time
|
||||
import copy
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
||||
class Trainer:
|
||||
def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
|
||||
scalers, logger, lr_scheduler=None):
|
||||
self.model = model
|
||||
self.loss = loss
|
||||
self.optimizer = optimizer
|
||||
self.train_loader = train_loader
|
||||
self.val_loader = val_loader
|
||||
self.test_loader = test_loader
|
||||
self.scalers = scalers # 现在是多个标准化器的列表
|
||||
self.args = config['train']
|
||||
self.logger = logger
|
||||
self.args['device'] = config['basic']['device']
|
||||
self.lr_scheduler = lr_scheduler
|
||||
self.train_per_epoch = len(train_loader)
|
||||
self.val_per_epoch = len(val_loader) if val_loader else 0
|
||||
self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
|
||||
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
||||
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
||||
|
||||
# 用于收集nfe数据
|
||||
self.c = []
|
||||
self.res, self.keys = [], []
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
self.model.train()
|
||||
optimizer_step = True
|
||||
else:
|
||||
self.model.eval()
|
||||
optimizer_step = False
|
||||
|
||||
total_loss = 0
|
||||
epoch_time = time.time()
|
||||
|
||||
# 清空nfe数据收集
|
||||
if mode == 'train':
|
||||
self.c.clear()
|
||||
|
||||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
label = target[..., :self.args['output_dim']]
|
||||
output, fe = self.model(data)
|
||||
|
||||
if self.args['real_value']:
|
||||
# 只对输出维度进行反归一化
|
||||
output = self._inverse_transform_output(output)
|
||||
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 收集nfe数据(仅在训练模式下)
|
||||
if mode == 'train':
|
||||
self.c.append([*fe, loss.item()])
|
||||
# 记录FE信息
|
||||
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
||||
|
||||
if optimizer_step and self.optimizer is not None:
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
||||
if self.args['grad_norm']:
|
||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
||||
self.optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
||||
self.logger.info(
|
||||
f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
|
||||
|
||||
# 更新 tqdm 的进度
|
||||
pbar.update(1)
|
||||
pbar.set_postfix(loss=loss.item())
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
|
||||
# 收集nfe数据(仅在训练模式下)
|
||||
if mode == 'train':
|
||||
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
|
||||
self.keys.append(epoch)
|
||||
|
||||
return avg_loss
|
||||
|
||||
def _inverse_transform_output(self, output):
|
||||
"""
|
||||
只对输出维度进行反归一化
|
||||
假设输出数据形状为 [batch, horizon, nodes, features]
|
||||
只对前output_dim个特征进行反归一化
|
||||
"""
|
||||
if not self.args['real_value']:
|
||||
return output
|
||||
|
||||
# 获取输出维度的数量
|
||||
output_dim = self.args['output_dim']
|
||||
|
||||
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||
if output_dim <= len(self.scalers):
|
||||
# 对每个输出特征分别进行反归一化
|
||||
for feature_idx in range(output_dim):
|
||||
if feature_idx < len(self.scalers):
|
||||
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||
output[..., feature_idx:feature_idx+1]
|
||||
)
|
||||
else:
|
||||
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||
for feature_idx in range(len(self.scalers)):
|
||||
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
||||
output[..., feature_idx:feature_idx+1]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
def train_epoch(self, epoch):
|
||||
return self._run_epoch(epoch, self.train_loader, 'train')
|
||||
|
||||
def val_epoch(self, epoch):
|
||||
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
|
||||
|
||||
def test_epoch(self, epoch):
|
||||
return self._run_epoch(epoch, self.test_loader, 'test')
|
||||
|
||||
def train(self):
|
||||
best_model, best_test_model = None, None
|
||||
best_loss, best_test_loss = float('inf'), float('inf')
|
||||
not_improved_count = 0
|
||||
|
||||
self.logger.logger.info("Training process started")
|
||||
for epoch in range(1, self.args['epochs'] + 1):
|
||||
train_epoch_loss = self.train_epoch(epoch)
|
||||
val_epoch_loss = self.val_epoch(epoch)
|
||||
test_epoch_loss = self.test_epoch(epoch)
|
||||
|
||||
if train_epoch_loss > 1e6:
|
||||
self.logger.logger.warning('Gradient explosion detected. Ending...')
|
||||
break
|
||||
|
||||
if val_epoch_loss < best_loss:
|
||||
best_loss = val_epoch_loss
|
||||
not_improved_count = 0
|
||||
best_model = copy.deepcopy(self.model.state_dict())
|
||||
torch.save(best_model, self.best_path)
|
||||
self.logger.logger.info('Best validation model saved!')
|
||||
else:
|
||||
not_improved_count += 1
|
||||
|
||||
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
|
||||
self.logger.logger.info(
|
||||
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
|
||||
break
|
||||
|
||||
if test_epoch_loss < best_test_loss:
|
||||
best_test_loss = test_epoch_loss
|
||||
best_test_model = copy.deepcopy(self.model.state_dict())
|
||||
torch.save(best_test_model, self.best_test_path)
|
||||
|
||||
# 保存nfe数据(如果启用)
|
||||
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
||||
self._save_nfe_data()
|
||||
|
||||
if not self.args['debug']:
|
||||
torch.save(best_model, self.best_path)
|
||||
torch.save(best_test_model, self.best_test_path)
|
||||
self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
||||
|
||||
self._finalize_training(best_model, best_test_model)
|
||||
|
||||
def _finalize_training(self, best_model, best_test_model):
|
||||
self.model.load_state_dict(best_model)
|
||||
self.logger.logger.info("Testing on best validation model")
|
||||
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
|
||||
|
||||
self.model.load_state_dict(best_test_model)
|
||||
self.logger.logger.info("Testing on best test model")
|
||||
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
|
||||
|
||||
@staticmethod
|
||||
def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
|
||||
if path:
|
||||
checkpoint = torch.load(path)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
model.to(args['device'])
|
||||
|
||||
model.eval()
|
||||
y_pred, y_true = [], []
|
||||
|
||||
# 用于收集nfe数据
|
||||
c = []
|
||||
|
||||
with torch.no_grad():
|
||||
for data, target in data_loader:
|
||||
label = target[..., :args['output_dim']]
|
||||
output, fe = model(data)
|
||||
y_pred.append(output)
|
||||
y_true.append(label)
|
||||
|
||||
# 收集nfe数据
|
||||
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
||||
|
||||
if args['real_value']:
|
||||
# 只对输出维度进行反归一化
|
||||
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
||||
else:
|
||||
y_pred = torch.cat(y_pred, dim=0)
|
||||
y_true = torch.cat(y_true, dim=0)
|
||||
|
||||
# 计算每个时间步的指标
|
||||
for t in range(y_true.shape[1]):
|
||||
mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
||||
args['mae_thresh'], args['mape_thresh'])
|
||||
logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||
|
||||
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
||||
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||
|
||||
# 保存nfe数据(如果启用)
|
||||
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
||||
Trainer._save_nfe_data_static(c, model, logger)
|
||||
|
||||
# 只在需要时生成可视化图片
|
||||
if generate_viz:
|
||||
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
||||
Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
|
||||
Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
||||
target_node=1, num_samples=10, scalers=scalers)
|
||||
|
||||
@staticmethod
|
||||
def _inverse_transform_output_static(output, args, scalers):
|
||||
"""
|
||||
静态方法:只对输出维度进行反归一化
|
||||
"""
|
||||
if not args['real_value']:
|
||||
return output
|
||||
|
||||
# 获取输出维度的数量
|
||||
output_dim = args['output_dim']
|
||||
|
||||
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
||||
if output_dim <= len(scalers):
|
||||
# 对每个输出特征分别进行反归一化
|
||||
for feature_idx in range(output_dim):
|
||||
if feature_idx < len(scalers):
|
||||
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||
output[..., feature_idx:feature_idx+1]
|
||||
)
|
||||
else:
|
||||
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
||||
for feature_idx in range(len(scalers)):
|
||||
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
||||
output[..., feature_idx:feature_idx+1]
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
|
||||
"""
|
||||
生成节点预测可视化图片
|
||||
|
||||
Args:
|
||||
y_pred: 预测值
|
||||
y_true: 真实值
|
||||
logger: 日志记录器
|
||||
save_dir: 保存目录
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
|
||||
# 设置matplotlib配置,减少字体查找输出
|
||||
matplotlib.set_loglevel('error') # 只显示错误信息
|
||||
plt.rcParams['font.family'] = 'DejaVu Sans' # 使用默认字体
|
||||
|
||||
# 检查数据有效性
|
||||
if y_pred is None or y_true is None:
|
||||
return
|
||||
|
||||
# 创建pic文件夹
|
||||
pic_dir = os.path.join(save_dir, 'pic')
|
||||
os.makedirs(pic_dir, exist_ok=True)
|
||||
|
||||
# 固定生成10张图片
|
||||
num_nodes_to_plot = 10
|
||||
|
||||
# 生成单个节点的详细图
|
||||
with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
|
||||
for node_id in range(num_nodes_to_plot):
|
||||
# 获取对应节点的数据
|
||||
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
||||
# 数据格式: [time_step, seq_len, num_node, dim]
|
||||
node_pred = y_pred[:, 12, node_id, 0].cpu().numpy() # t=1时刻,指定节点,第一个特征
|
||||
node_true = y_true[:, 12, node_id, 0].cpu().numpy()
|
||||
else:
|
||||
# 如果数据不足10个节点,只处理实际存在的节点
|
||||
if node_id >= y_pred.shape[-2]:
|
||||
pbar.update(1)
|
||||
continue
|
||||
else:
|
||||
node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
|
||||
node_true = y_true[:, 0, node_id, 0].cpu().numpy()
|
||||
|
||||
# 检查数据有效性
|
||||
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
# 取前500个时间步
|
||||
max_steps = min(500, len(node_pred))
|
||||
if max_steps <= 0:
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
node_pred_500 = node_pred[:max_steps]
|
||||
node_true_500 = node_true[:max_steps]
|
||||
|
||||
# 创建时间轴
|
||||
time_steps = np.arange(max_steps)
|
||||
|
||||
# 绘制对比图
|
||||
plt.figure(figsize=(12, 6))
|
||||
plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
|
||||
plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
|
||||
plt.xlabel('Time Steps')
|
||||
plt.ylabel('Values')
|
||||
plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
# 保存图片,使用不同的命名
|
||||
save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# 生成所有节点的对比图(前100个时间步,便于观察)
|
||||
# 选择前100个时间步
|
||||
plot_steps = min(100, y_pred.shape[0])
|
||||
if plot_steps <= 0:
|
||||
return
|
||||
|
||||
# 创建子图
|
||||
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
||||
axes = axes.flatten()
|
||||
|
||||
for node_id in range(num_nodes_to_plot):
|
||||
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
||||
# 数据格式: [time_step, seq_len, num_node, dim]
|
||||
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||
else:
|
||||
# 如果数据不足10个节点,只处理实际存在的节点
|
||||
if node_id >= y_pred.shape[-2]:
|
||||
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
||||
ha='center', va='center', transform=axes[node_id].transAxes)
|
||||
continue
|
||||
else:
|
||||
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
||||
|
||||
# 检查数据有效性
|
||||
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
||||
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
||||
ha='center', va='center', transform=axes[node_id].transAxes)
|
||||
continue
|
||||
|
||||
time_steps = np.arange(plot_steps)
|
||||
|
||||
axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
|
||||
axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
|
||||
axes[node_id].set_title(f'Node {node_id + 1}')
|
||||
axes[node_id].grid(True, alpha=0.3)
|
||||
axes[node_id].legend(fontsize=8)
|
||||
|
||||
if node_id >= 5: # 下面一行添加x轴标签
|
||||
axes[node_id].set_xlabel('Time Steps')
|
||||
if node_id % 5 == 0: # 左边一列添加y轴标签
|
||||
axes[node_id].set_ylabel('Values')
|
||||
|
||||
plt.tight_layout()
|
||||
summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
|
||||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
@staticmethod
|
||||
def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
||||
target_node=1, num_samples=10, scalers=None):
|
||||
"""
|
||||
生成输入-输出样本比较图
|
||||
|
||||
Args:
|
||||
y_pred: 预测值
|
||||
y_true: 真实值
|
||||
data_loader: 数据加载器,用于获取输入数据
|
||||
logger: 日志记录器
|
||||
save_dir: 保存目录
|
||||
target_node: 目标节点ID(从1开始)
|
||||
num_samples: 要比较的样本数量
|
||||
scalers: 标准化器列表,用于反归一化输入数据
|
||||
"""
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import os
|
||||
import matplotlib
|
||||
from tqdm import tqdm
|
||||
|
||||
# 设置matplotlib配置
|
||||
matplotlib.set_loglevel('error')
|
||||
plt.rcParams['font.family'] = 'DejaVu Sans'
|
||||
|
||||
# 创建compare文件夹
|
||||
compare_dir = os.path.join(save_dir, 'pic', 'compare')
|
||||
os.makedirs(compare_dir, exist_ok=True)
|
||||
|
||||
# 获取输入数据
|
||||
input_data = []
|
||||
for batch_idx, (data, target) in enumerate(data_loader):
|
||||
if batch_idx >= num_samples:
|
||||
break
|
||||
input_data.append(data.cpu().numpy())
|
||||
|
||||
if not input_data:
|
||||
return
|
||||
|
||||
# 获取目标节点的索引(从0开始)
|
||||
node_idx = target_node - 1
|
||||
|
||||
# 检查节点索引是否有效
|
||||
if node_idx >= y_pred.shape[-2]:
|
||||
return
|
||||
|
||||
# 为每个样本生成比较图
|
||||
with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
|
||||
for sample_idx in range(min(num_samples, len(input_data))):
|
||||
# 获取输入序列(假设输入形状为 [batch, seq_len, nodes, features])
|
||||
input_seq = input_data[sample_idx][0, :, node_idx, 0] # 第一个batch,所有时间步,目标节点,第一个特征
|
||||
|
||||
# 对输入数据进行反归一化
|
||||
if scalers is not None and len(scalers) > 0:
|
||||
# 使用第一个标准化器对输入进行反归一化(假设输入特征使用第一个标准化器)
|
||||
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
||||
|
||||
# 获取对应的预测值和真实值
|
||||
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy() # 所有horizon,目标节点,第一个特征
|
||||
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||
|
||||
# 检查数据有效性
|
||||
if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
|
||||
pbar.update(1)
|
||||
continue
|
||||
|
||||
# 创建时间轴 - 输入和输出连续
|
||||
total_time = np.arange(len(input_seq) + len(pred_seq))
|
||||
|
||||
# 创建合并的图形 - 输入和输出在同一个图中
|
||||
plt.figure(figsize=(14, 8))
|
||||
|
||||
# 绘制完整的真实值曲线(输入 + 真实输出)
|
||||
true_combined = np.concatenate([input_seq, true_seq])
|
||||
plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
|
||||
linewidth=2.5, alpha=0.9, linestyle='-')
|
||||
|
||||
# 绘制预测值曲线(只绘制输出部分)
|
||||
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
||||
plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
|
||||
linewidth=2, alpha=0.8, linestyle='-')
|
||||
|
||||
# 添加垂直线分隔输入和输出
|
||||
plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
|
||||
label='Input/Output Boundary')
|
||||
|
||||
# 设置图形属性
|
||||
plt.xlabel('Time Steps')
|
||||
plt.ylabel('Values')
|
||||
plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
|
||||
plt.legend()
|
||||
plt.grid(True, alpha=0.3)
|
||||
|
||||
# 调整布局
|
||||
plt.tight_layout()
|
||||
|
||||
# 保存图片
|
||||
save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
|
||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
# 生成汇总图(所有样本的预测值对比)
|
||||
|
||||
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
||||
axes = axes.flatten()
|
||||
|
||||
for sample_idx in range(min(num_samples, len(input_data))):
|
||||
if sample_idx >= 10: # 最多显示10个子图
|
||||
break
|
||||
|
||||
ax = axes[sample_idx]
|
||||
|
||||
# 获取输入序列和预测值、真实值
|
||||
input_seq = input_data[sample_idx][0, :, node_idx, 0]
|
||||
if scalers is not None and len(scalers) > 0:
|
||||
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
||||
|
||||
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
||||
|
||||
# 检查数据有效性
|
||||
if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
|
||||
ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
|
||||
ha='center', va='center', transform=ax.transAxes)
|
||||
continue
|
||||
|
||||
# 绘制对比图 - 输入和输出连续显示
|
||||
total_time = np.arange(len(input_seq) + len(pred_seq))
|
||||
true_combined = np.concatenate([input_seq, true_seq])
|
||||
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
||||
|
||||
ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
|
||||
ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
|
||||
ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
|
||||
ax.set_title(f'Sample {sample_idx + 1}')
|
||||
ax.grid(True, alpha=0.3)
|
||||
ax.legend(fontsize=8)
|
||||
|
||||
if sample_idx >= 5: # 下面一行添加x轴标签
|
||||
ax.set_xlabel('Time Steps')
|
||||
if sample_idx % 5 == 0: # 左边一列添加y轴标签
|
||||
ax.set_ylabel('Values')
|
||||
|
||||
# 隐藏多余的子图
|
||||
for i in range(min(num_samples, len(input_data)), 10):
|
||||
axes[i].set_visible(False)
|
||||
|
||||
plt.tight_layout()
|
||||
summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
|
||||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def _save_nfe_data(self):
|
||||
"""保存nfe数据到文件"""
|
||||
if not self.res:
|
||||
return
|
||||
|
||||
res = pd.concat(self.res, keys=self.keys)
|
||||
res.index.names = ['epoch', 'iter']
|
||||
|
||||
# 获取模型配置参数
|
||||
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
||||
atol = getattr(self.model, 'atol', 1e-5)
|
||||
rtol = getattr(self.model, 'rtol', 1e-5)
|
||||
|
||||
# 保存nfe数据
|
||||
nfe_file = os.path.join(
|
||||
self.logger.dir_path,
|
||||
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||
res.to_pickle(nfe_file)
|
||||
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
||||
|
||||
@staticmethod
|
||||
def _compute_sampling_threshold(global_step, k):
|
||||
return k / (k + math.exp(global_step / k))
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
from trainer.trainer import Trainer
|
||||
from trainer.ode_trainer import Trainer as ode_trainer
|
||||
|
||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||
lr_scheduler, kwargs):
|
||||
model_name = config['basic']['model']
|
||||
selected_Trainer = None
|
||||
match model_name:
|
||||
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||
if selected_Trainer is None: raise NotImplementedError
|
||||
return selected_Trainer
|
||||
Loading…
Reference in New Issue