Compare commits
No commits in common. "66a23ffbbb2087525fb3a38bb10ad1b7d6183512" and "19a02ba7ae9e4ec89a46ff4cb1141fb51323e762" have entirely different histories.
66a23ffbbb
...
19a02ba7ae
|
|
@ -160,4 +160,3 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
STDEN/
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
# 默认忽略的文件
|
|
||||||
/shelf/
|
|
||||||
/workspace.xml
|
|
||||||
# 基于编辑器的 HTTP 客户端请求
|
|
||||||
/httpRequests/
|
|
||||||
# Datasource local storage ignored files
|
|
||||||
/dataSources/
|
|
||||||
/dataSources.local.xml
|
|
||||||
|
|
@ -1,12 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<module type="PYTHON_MODULE" version="4">
|
|
||||||
<component name="NewModuleRootManager">
|
|
||||||
<content url="file://$MODULE_DIR$" />
|
|
||||||
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
|
||||||
<orderEntry type="sourceFolder" forTests="false" />
|
|
||||||
</component>
|
|
||||||
<component name="PyDocumentationSettings">
|
|
||||||
<option name="format" value="PLAIN" />
|
|
||||||
<option name="myDocStringFormat" value="Plain" />
|
|
||||||
</component>
|
|
||||||
</module>
|
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<profile version="1.0">
|
|
||||||
<option name="myName" value="Project Default" />
|
|
||||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
|
||||||
<Languages>
|
|
||||||
<language minSize="136" name="Python" />
|
|
||||||
</Languages>
|
|
||||||
</inspection_tool>
|
|
||||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
|
||||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
|
||||||
<option name="ignoredPackages">
|
|
||||||
<value>
|
|
||||||
<list size="8">
|
|
||||||
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
|
||||||
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
|
||||||
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
|
||||||
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
|
||||||
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
|
||||||
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
|
||||||
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
|
||||||
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
|
||||||
</list>
|
|
||||||
</value>
|
|
||||||
</option>
|
|
||||||
</inspection_tool>
|
|
||||||
</profile>
|
|
||||||
</component>
|
|
||||||
|
|
@ -1,6 +0,0 @@
|
||||||
<component name="InspectionProjectProfileManager">
|
|
||||||
<settings>
|
|
||||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
|
||||||
<version value="1.0" />
|
|
||||||
</settings>
|
|
||||||
</component>
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="Black">
|
|
||||||
<option name="sdkName" value="Python 3.10" />
|
|
||||||
</component>
|
|
||||||
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
|
||||||
</project>
|
|
||||||
|
|
@ -1,8 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="ProjectModuleManager">
|
|
||||||
<modules>
|
|
||||||
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
|
||||||
</modules>
|
|
||||||
</component>
|
|
||||||
</project>
|
|
||||||
|
|
@ -1,7 +0,0 @@
|
||||||
<?xml version="1.0" encoding="UTF-8"?>
|
|
||||||
<project version="4">
|
|
||||||
<component name="VcsDirectoryMappings">
|
|
||||||
<mapping directory="" vcs="Git" />
|
|
||||||
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
|
||||||
</component>
|
|
||||||
</project>
|
|
||||||
|
|
@ -1,296 +0,0 @@
|
||||||
from,to,cost
|
|
||||||
9,153,310.6
|
|
||||||
153,62,330.9
|
|
||||||
62,111,332.9
|
|
||||||
111,11,324.2
|
|
||||||
11,28,336.0
|
|
||||||
28,169,133.7
|
|
||||||
138,135,354.7
|
|
||||||
135,133,387.9
|
|
||||||
133,163,337.1
|
|
||||||
163,20,352.0
|
|
||||||
20,19,420.8
|
|
||||||
19,14,351.3
|
|
||||||
14,39,340.2
|
|
||||||
39,164,350.3
|
|
||||||
164,167,365.2
|
|
||||||
167,70,359.0
|
|
||||||
70,59,388.2
|
|
||||||
59,58,305.7
|
|
||||||
58,67,294.4
|
|
||||||
67,66,299.5
|
|
||||||
66,55,313.3
|
|
||||||
55,53,332.1
|
|
||||||
53,150,278.9
|
|
||||||
150,61,308.4
|
|
||||||
61,64,311.4
|
|
||||||
64,63,243.6
|
|
||||||
47,65,372.8
|
|
||||||
65,48,319.4
|
|
||||||
48,49,309.7
|
|
||||||
49,54,320.5
|
|
||||||
54,56,318.3
|
|
||||||
56,57,297.9
|
|
||||||
57,68,293.5
|
|
||||||
68,69,342.5
|
|
||||||
69,60,318.0
|
|
||||||
60,17,305.9
|
|
||||||
17,5,321.4
|
|
||||||
5,18,402.2
|
|
||||||
18,22,447.4
|
|
||||||
22,30,377.5
|
|
||||||
30,29,417.7
|
|
||||||
29,21,360.8
|
|
||||||
21,132,407.6
|
|
||||||
132,134,386.9
|
|
||||||
134,136,350.2
|
|
||||||
123,121,326.3
|
|
||||||
121,140,385.2
|
|
||||||
140,118,393.0
|
|
||||||
118,96,296.7
|
|
||||||
96,94,398.2
|
|
||||||
94,86,337.1
|
|
||||||
86,78,473.8
|
|
||||||
78,46,353.4
|
|
||||||
46,152,385.7
|
|
||||||
152,157,350.0
|
|
||||||
157,35,354.4
|
|
||||||
35,77,356.1
|
|
||||||
77,52,354.2
|
|
||||||
52,3,357.8
|
|
||||||
3,16,382.4
|
|
||||||
16,0,55.7
|
|
||||||
42,12,335.1
|
|
||||||
12,139,328.8
|
|
||||||
139,168,412.6
|
|
||||||
168,154,337.3
|
|
||||||
154,143,370.7
|
|
||||||
143,10,6.3
|
|
||||||
107,105,354.6
|
|
||||||
105,104,386.9
|
|
||||||
104,148,362.1
|
|
||||||
148,97,316.3
|
|
||||||
97,101,380.7
|
|
||||||
101,137,361.4
|
|
||||||
137,102,365.5
|
|
||||||
102,24,375.5
|
|
||||||
24,166,312.2
|
|
||||||
129,156,256.1
|
|
||||||
156,33,329.1
|
|
||||||
33,32,356.5
|
|
||||||
91,89,405.6
|
|
||||||
89,147,347.0
|
|
||||||
147,15,351.7
|
|
||||||
15,44,339.5
|
|
||||||
44,41,350.8
|
|
||||||
41,43,322.6
|
|
||||||
43,100,338.9
|
|
||||||
100,83,347.9
|
|
||||||
83,87,327.2
|
|
||||||
87,88,321.0
|
|
||||||
88,75,335.8
|
|
||||||
75,51,384.8
|
|
||||||
51,73,391.1
|
|
||||||
73,71,289.3
|
|
||||||
31,155,260.0
|
|
||||||
155,34,320.4
|
|
||||||
34,128,393.3
|
|
||||||
145,115,399.4
|
|
||||||
115,112,328.1
|
|
||||||
112,8,469.4
|
|
||||||
8,117,816.2
|
|
||||||
117,125,397.1
|
|
||||||
125,127,372.7
|
|
||||||
127,109,380.5
|
|
||||||
109,161,355.5
|
|
||||||
161,110,367.7
|
|
||||||
110,160,102.0
|
|
||||||
72,159,342.9
|
|
||||||
159,50,383.3
|
|
||||||
50,74,354.1
|
|
||||||
74,82,350.2
|
|
||||||
82,81,335.4
|
|
||||||
81,99,391.6
|
|
||||||
99,84,354.9
|
|
||||||
84,13,306.4
|
|
||||||
13,40,327.4
|
|
||||||
40,162,413.9
|
|
||||||
162,108,301.9
|
|
||||||
108,146,317.8
|
|
||||||
146,85,376.6
|
|
||||||
85,90,347.0
|
|
||||||
26,27,341.6
|
|
||||||
27,6,359.4
|
|
||||||
6,149,417.8
|
|
||||||
149,126,388.0
|
|
||||||
126,124,384.3
|
|
||||||
124,7,763.3
|
|
||||||
7,114,323.1
|
|
||||||
114,113,351.6
|
|
||||||
113,116,411.9
|
|
||||||
116,144,262.0
|
|
||||||
25,103,350.2
|
|
||||||
103,23,376.3
|
|
||||||
23,165,396.4
|
|
||||||
165,38,381.0
|
|
||||||
38,92,368.0
|
|
||||||
92,37,336.3
|
|
||||||
37,130,357.8
|
|
||||||
130,106,532.3
|
|
||||||
106,131,166.5
|
|
||||||
1,2,371.6
|
|
||||||
2,4,338.1
|
|
||||||
4,76,429.0
|
|
||||||
76,36,366.1
|
|
||||||
36,158,344.5
|
|
||||||
158,151,350.1
|
|
||||||
151,45,358.8
|
|
||||||
45,93,340.9
|
|
||||||
93,80,329.9
|
|
||||||
80,79,384.1
|
|
||||||
79,95,335.7
|
|
||||||
95,98,320.9
|
|
||||||
98,119,340.3
|
|
||||||
119,120,376.8
|
|
||||||
120,122,393.1
|
|
||||||
122,141,428.7
|
|
||||||
141,142,359.3
|
|
||||||
30,165,379.6
|
|
||||||
165,29,41.7
|
|
||||||
29,38,343.3
|
|
||||||
65,72,297.9
|
|
||||||
72,48,21.5
|
|
||||||
17,153,375.6
|
|
||||||
153,5,256.3
|
|
||||||
153,62,330.9
|
|
||||||
18,6,499.4
|
|
||||||
6,22,254.0
|
|
||||||
22,149,185.4
|
|
||||||
22,4,257.9
|
|
||||||
4,30,236.8
|
|
||||||
30,76,307.0
|
|
||||||
95,98,320.9
|
|
||||||
98,144,45.1
|
|
||||||
45,93,340.9
|
|
||||||
93,106,112.2
|
|
||||||
162,151,113.6
|
|
||||||
151,108,192.9
|
|
||||||
108,45,359.8
|
|
||||||
146,92,311.2
|
|
||||||
92,85,343.9
|
|
||||||
85,37,373.2
|
|
||||||
13,169,326.2
|
|
||||||
169,40,96.1
|
|
||||||
124,13,460.7
|
|
||||||
13,7,305.5
|
|
||||||
7,40,624.1
|
|
||||||
124,169,145.2
|
|
||||||
169,7,631.5
|
|
||||||
90,132,152.2
|
|
||||||
26,32,106.7
|
|
||||||
9,129,148.3
|
|
||||||
129,153,219.6
|
|
||||||
31,26,116.0
|
|
||||||
26,155,270.7
|
|
||||||
9,128,142.2
|
|
||||||
128,153,215.0
|
|
||||||
153,167,269.7
|
|
||||||
167,62,64.8
|
|
||||||
62,70,332.6
|
|
||||||
124,169,145.2
|
|
||||||
169,7,631.5
|
|
||||||
44,169,397.8
|
|
||||||
169,41,124.0
|
|
||||||
44,124,375.7
|
|
||||||
124,41,243.9
|
|
||||||
41,7,519.4
|
|
||||||
6,14,289.3
|
|
||||||
14,149,259.0
|
|
||||||
149,39,206.9
|
|
||||||
144,98,45.1
|
|
||||||
19,4,326.8
|
|
||||||
4,14,178.6
|
|
||||||
14,76,299.0
|
|
||||||
15,151,136.4
|
|
||||||
151,44,203.1
|
|
||||||
45,106,260.6
|
|
||||||
106,93,112.2
|
|
||||||
20,165,132.5
|
|
||||||
165,19,289.2
|
|
||||||
89,92,323.2
|
|
||||||
92,147,321.9
|
|
||||||
147,37,48.2
|
|
||||||
133,91,152.8
|
|
||||||
91,163,313.6
|
|
||||||
150,71,221.1
|
|
||||||
71,61,89.6
|
|
||||||
78,107,143.9
|
|
||||||
107,46,236.3
|
|
||||||
104,147,277.5
|
|
||||||
147,148,84.7
|
|
||||||
20,101,201.2
|
|
||||||
101,19,534.4
|
|
||||||
19,137,245.5
|
|
||||||
8,42,759.5
|
|
||||||
42,117,58.9
|
|
||||||
44,42,342.3
|
|
||||||
42,41,102.5
|
|
||||||
44,8,789.1
|
|
||||||
8,41,657.4
|
|
||||||
41,117,160.5
|
|
||||||
168,167,172.4
|
|
||||||
167,154,165.2
|
|
||||||
143,128,81.9
|
|
||||||
128,10,88.2
|
|
||||||
118,145,250.6
|
|
||||||
145,96,85.1
|
|
||||||
15,152,135.0
|
|
||||||
152,44,204.6
|
|
||||||
19,77,320.7
|
|
||||||
77,14,299.8
|
|
||||||
14,52,127.6
|
|
||||||
14,127,314.8
|
|
||||||
127,39,280.4
|
|
||||||
39,109,237.0
|
|
||||||
31,160,116.5
|
|
||||||
160,155,272.4
|
|
||||||
133,91,152.8
|
|
||||||
91,163,313.6
|
|
||||||
150,71,221.1
|
|
||||||
71,61,89.6
|
|
||||||
32,160,107.7
|
|
||||||
72,162,3274.4
|
|
||||||
162,13,554.5
|
|
||||||
162,40,413.9
|
|
||||||
65,72,297.9
|
|
||||||
72,48,21.5
|
|
||||||
13,42,319.8
|
|
||||||
42,40,40.7
|
|
||||||
8,42,759.5
|
|
||||||
42,117,58.9
|
|
||||||
8,13,450.3
|
|
||||||
13,117,378.5
|
|
||||||
117,40,64.0
|
|
||||||
46,162,391.6
|
|
||||||
162,152,115.3
|
|
||||||
152,108,191.4
|
|
||||||
104,108,375.9
|
|
||||||
108,148,311.6
|
|
||||||
148,146,80.0
|
|
||||||
21,90,396.9
|
|
||||||
90,132,152.2
|
|
||||||
101,29,252.3
|
|
||||||
29,137,110.7
|
|
||||||
77,22,353.8
|
|
||||||
22,52,227.8
|
|
||||||
52,30,186.6
|
|
||||||
127,18,425.2
|
|
||||||
18,109,439.1
|
|
||||||
109,22,135.5
|
|
||||||
168,17,232.7
|
|
||||||
17,154,294.2
|
|
||||||
154,5,166.3
|
|
||||||
78,107,143.9
|
|
||||||
107,46,236.3
|
|
||||||
118,145,250.6
|
|
||||||
145,96,85.1
|
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,7 +2,6 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def load_dataset(config):
|
def load_dataset(config):
|
||||||
|
|
||||||
dataset_name = config['basic']['dataset']
|
dataset_name = config['basic']['dataset']
|
||||||
node_num = config['data']['num_nodes']
|
node_num = config['data']['num_nodes']
|
||||||
input_dim = config['data']['input_dim']
|
input_dim = config['data']['input_dim']
|
||||||
|
|
@ -11,8 +10,4 @@ def load_dataset(config):
|
||||||
case 'EcoSolar':
|
case 'EcoSolar':
|
||||||
data_path = os.path.join('./data/EcoSolar.npy')
|
data_path = os.path.join('./data/EcoSolar.npy')
|
||||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||||
case 'PEMS08':
|
|
||||||
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
|
||||||
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
@ -1,9 +0,0 @@
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
def load_graph(config):
|
|
||||||
dataset_path = config['data']['graph_pkl_filename']
|
|
||||||
graph = np.load(dataset_path)
|
|
||||||
# 将inf值填充为0
|
|
||||||
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
|
||||||
|
|
||||||
return graph
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 100
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 100
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 170
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 170
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 170
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
|
|
@ -1,62 +0,0 @@
|
||||||
basic:
|
|
||||||
dataset: PEMS08
|
|
||||||
device: cuda:0
|
|
||||||
mode: train
|
|
||||||
model: STDEN
|
|
||||||
seed: 2025
|
|
||||||
data:
|
|
||||||
add_day_in_week: true
|
|
||||||
add_time_in_day: true
|
|
||||||
batch_size: 32
|
|
||||||
column_wise: false
|
|
||||||
dataset_dir: data/PEMS08
|
|
||||||
days_per_week: 7
|
|
||||||
default_graph: true
|
|
||||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
|
||||||
horizon: 24
|
|
||||||
input_dim: 1
|
|
||||||
lag: 24
|
|
||||||
normalizer: std
|
|
||||||
num_nodes: 170
|
|
||||||
steps_per_day: 24
|
|
||||||
test_ratio: 0.2
|
|
||||||
tod: false
|
|
||||||
val_batch_size: 32
|
|
||||||
val_ratio: 0.2
|
|
||||||
model:
|
|
||||||
filter_type: default
|
|
||||||
gcn_step: 2
|
|
||||||
horizon: 12
|
|
||||||
input_dim: 1
|
|
||||||
l1_decay: 0
|
|
||||||
latent_dim: 4
|
|
||||||
n_traj_samples: 3
|
|
||||||
nfe: false
|
|
||||||
num_rnn_layers: 1
|
|
||||||
ode_method: dopri5
|
|
||||||
odeint_atol: 1.0e-05
|
|
||||||
odeint_rtol: 1.0e-05
|
|
||||||
output_dim: 1
|
|
||||||
recg_type: gru
|
|
||||||
rnn_units: 64
|
|
||||||
save_latent: false
|
|
||||||
seq_len: 12
|
|
||||||
train:
|
|
||||||
batch_size: 64
|
|
||||||
debug: false
|
|
||||||
early_stop: true
|
|
||||||
early_stop_patience: 15
|
|
||||||
epochs: 100
|
|
||||||
grad_norm: false
|
|
||||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
|
|
||||||
loss: mae
|
|
||||||
lr_decay: false
|
|
||||||
lr_decay_rate: 0.3
|
|
||||||
lr_decay_step: 5,20,40,70
|
|
||||||
lr_init: 0.003
|
|
||||||
mae_thresh: None
|
|
||||||
mape_thresh: 0.001
|
|
||||||
max_grad_norm: 5
|
|
||||||
output_dim: 1
|
|
||||||
real_value: true
|
|
||||||
weight_decay: 0
|
|
||||||
2
main.py
2
main.py
|
|
@ -3,6 +3,8 @@
|
||||||
时空数据深度学习预测项目主程序
|
时空数据深度学习预测项目主程序
|
||||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
from utils.args_reader import config_loader
|
from utils.args_reader import config_loader
|
||||||
import utils.init as init
|
import utils.init as init
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,6 @@
|
||||||
from models.STDEN.stden_model import STDENModel
|
|
||||||
|
|
||||||
def model_selector(config):
|
def model_selector(config):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
model = None
|
model = None
|
||||||
match model_name:
|
|
||||||
case 'STDEN': model = STDENModel(config)
|
|
||||||
return model
|
return model
|
||||||
|
|
@ -1,576 +0,0 @@
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import copy
|
|
||||||
import pandas as pd
|
|
||||||
import numpy as np
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
import torch
|
|
||||||
|
|
||||||
class Trainer:
|
|
||||||
def __init__(self, config, model, loss, optimizer, train_loader, val_loader, test_loader,
|
|
||||||
scalers, logger, lr_scheduler=None):
|
|
||||||
self.model = model
|
|
||||||
self.loss = loss
|
|
||||||
self.optimizer = optimizer
|
|
||||||
self.train_loader = train_loader
|
|
||||||
self.val_loader = val_loader
|
|
||||||
self.test_loader = test_loader
|
|
||||||
self.scalers = scalers # 现在是多个标准化器的列表
|
|
||||||
self.args = config['train']
|
|
||||||
self.logger = logger
|
|
||||||
self.args['device'] = config['basic']['device']
|
|
||||||
self.lr_scheduler = lr_scheduler
|
|
||||||
self.train_per_epoch = len(train_loader)
|
|
||||||
self.val_per_epoch = len(val_loader) if val_loader else 0
|
|
||||||
self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
|
|
||||||
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
|
||||||
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
|
||||||
|
|
||||||
# 用于收集nfe数据
|
|
||||||
self.c = []
|
|
||||||
self.res, self.keys = [], []
|
|
||||||
|
|
||||||
def _run_epoch(self, epoch, dataloader, mode):
|
|
||||||
if mode == 'train':
|
|
||||||
self.model.train()
|
|
||||||
optimizer_step = True
|
|
||||||
else:
|
|
||||||
self.model.eval()
|
|
||||||
optimizer_step = False
|
|
||||||
|
|
||||||
total_loss = 0
|
|
||||||
epoch_time = time.time()
|
|
||||||
|
|
||||||
# 清空nfe数据收集
|
|
||||||
if mode == 'train':
|
|
||||||
self.c.clear()
|
|
||||||
|
|
||||||
with torch.set_grad_enabled(optimizer_step):
|
|
||||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
|
||||||
for batch_idx, (data, target) in enumerate(dataloader):
|
|
||||||
label = target[..., :self.args['output_dim']]
|
|
||||||
output, fe = self.model(data)
|
|
||||||
|
|
||||||
if self.args['real_value']:
|
|
||||||
# 只对输出维度进行反归一化
|
|
||||||
output = self._inverse_transform_output(output)
|
|
||||||
|
|
||||||
loss = self.loss(output, label)
|
|
||||||
|
|
||||||
# 收集nfe数据(仅在训练模式下)
|
|
||||||
if mode == 'train':
|
|
||||||
self.c.append([*fe, loss.item()])
|
|
||||||
# 记录FE信息
|
|
||||||
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
|
||||||
|
|
||||||
if optimizer_step and self.optimizer is not None:
|
|
||||||
self.optimizer.zero_grad()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
if self.args['grad_norm']:
|
|
||||||
torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args['max_grad_norm'])
|
|
||||||
self.optimizer.step()
|
|
||||||
|
|
||||||
total_loss += loss.item()
|
|
||||||
|
|
||||||
if mode == 'train' and (batch_idx + 1) % self.args['log_step'] == 0:
|
|
||||||
self.logger.info(
|
|
||||||
f'Train Epoch {epoch}: {batch_idx + 1}/{len(dataloader)} Loss: {loss.item():.6f}')
|
|
||||||
|
|
||||||
# 更新 tqdm 的进度
|
|
||||||
pbar.update(1)
|
|
||||||
pbar.set_postfix(loss=loss.item())
|
|
||||||
|
|
||||||
avg_loss = total_loss / len(dataloader)
|
|
||||||
self.logger.logger.info(
|
|
||||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
|
||||||
|
|
||||||
# 收集nfe数据(仅在训练模式下)
|
|
||||||
if mode == 'train':
|
|
||||||
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
|
|
||||||
self.keys.append(epoch)
|
|
||||||
|
|
||||||
return avg_loss
|
|
||||||
|
|
||||||
def _inverse_transform_output(self, output):
|
|
||||||
"""
|
|
||||||
只对输出维度进行反归一化
|
|
||||||
假设输出数据形状为 [batch, horizon, nodes, features]
|
|
||||||
只对前output_dim个特征进行反归一化
|
|
||||||
"""
|
|
||||||
if not self.args['real_value']:
|
|
||||||
return output
|
|
||||||
|
|
||||||
# 获取输出维度的数量
|
|
||||||
output_dim = self.args['output_dim']
|
|
||||||
|
|
||||||
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
|
||||||
if output_dim <= len(self.scalers):
|
|
||||||
# 对每个输出特征分别进行反归一化
|
|
||||||
for feature_idx in range(output_dim):
|
|
||||||
if feature_idx < len(self.scalers):
|
|
||||||
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
|
||||||
output[..., feature_idx:feature_idx+1]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
|
||||||
for feature_idx in range(len(self.scalers)):
|
|
||||||
output[..., feature_idx:feature_idx+1] = self.scalers[feature_idx].inverse_transform(
|
|
||||||
output[..., feature_idx:feature_idx+1]
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
def train_epoch(self, epoch):
|
|
||||||
return self._run_epoch(epoch, self.train_loader, 'train')
|
|
||||||
|
|
||||||
def val_epoch(self, epoch):
|
|
||||||
return self._run_epoch(epoch, self.val_loader or self.test_loader, 'val')
|
|
||||||
|
|
||||||
def test_epoch(self, epoch):
|
|
||||||
return self._run_epoch(epoch, self.test_loader, 'test')
|
|
||||||
|
|
||||||
def train(self):
|
|
||||||
best_model, best_test_model = None, None
|
|
||||||
best_loss, best_test_loss = float('inf'), float('inf')
|
|
||||||
not_improved_count = 0
|
|
||||||
|
|
||||||
self.logger.logger.info("Training process started")
|
|
||||||
for epoch in range(1, self.args['epochs'] + 1):
|
|
||||||
train_epoch_loss = self.train_epoch(epoch)
|
|
||||||
val_epoch_loss = self.val_epoch(epoch)
|
|
||||||
test_epoch_loss = self.test_epoch(epoch)
|
|
||||||
|
|
||||||
if train_epoch_loss > 1e6:
|
|
||||||
self.logger.logger.warning('Gradient explosion detected. Ending...')
|
|
||||||
break
|
|
||||||
|
|
||||||
if val_epoch_loss < best_loss:
|
|
||||||
best_loss = val_epoch_loss
|
|
||||||
not_improved_count = 0
|
|
||||||
best_model = copy.deepcopy(self.model.state_dict())
|
|
||||||
torch.save(best_model, self.best_path)
|
|
||||||
self.logger.logger.info('Best validation model saved!')
|
|
||||||
else:
|
|
||||||
not_improved_count += 1
|
|
||||||
|
|
||||||
if self.args['early_stop'] and not_improved_count == self.args['early_stop_patience']:
|
|
||||||
self.logger.logger.info(
|
|
||||||
f"Validation performance didn't improve for {self.args['early_stop_patience']} epochs. Training stops.")
|
|
||||||
break
|
|
||||||
|
|
||||||
if test_epoch_loss < best_test_loss:
|
|
||||||
best_test_loss = test_epoch_loss
|
|
||||||
best_test_model = copy.deepcopy(self.model.state_dict())
|
|
||||||
torch.save(best_test_model, self.best_test_path)
|
|
||||||
|
|
||||||
# 保存nfe数据(如果启用)
|
|
||||||
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
|
||||||
self._save_nfe_data()
|
|
||||||
|
|
||||||
if not self.args['debug']:
|
|
||||||
torch.save(best_model, self.best_path)
|
|
||||||
torch.save(best_test_model, self.best_test_path)
|
|
||||||
self.logger.logger.info(f"Best models saved at {self.best_path} and {self.best_test_path}")
|
|
||||||
|
|
||||||
self._finalize_training(best_model, best_test_model)
|
|
||||||
|
|
||||||
def _finalize_training(self, best_model, best_test_model):
|
|
||||||
self.model.load_state_dict(best_model)
|
|
||||||
self.logger.logger.info("Testing on best validation model")
|
|
||||||
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=False)
|
|
||||||
|
|
||||||
self.model.load_state_dict(best_test_model)
|
|
||||||
self.logger.logger.info("Testing on best test model")
|
|
||||||
self.test(self.model, self.args, self.test_loader, self.scalers, self.logger, generate_viz=True)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def test(model, args, data_loader, scalers, logger, path=None, generate_viz=True):
|
|
||||||
if path:
|
|
||||||
checkpoint = torch.load(path)
|
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
|
||||||
model.to(args['device'])
|
|
||||||
|
|
||||||
model.eval()
|
|
||||||
y_pred, y_true = [], []
|
|
||||||
|
|
||||||
# 用于收集nfe数据
|
|
||||||
c = []
|
|
||||||
|
|
||||||
with torch.no_grad():
|
|
||||||
for data, target in data_loader:
|
|
||||||
label = target[..., :args['output_dim']]
|
|
||||||
output, fe = model(data)
|
|
||||||
y_pred.append(output)
|
|
||||||
y_true.append(label)
|
|
||||||
|
|
||||||
# 收集nfe数据
|
|
||||||
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
|
||||||
|
|
||||||
if args['real_value']:
|
|
||||||
# 只对输出维度进行反归一化
|
|
||||||
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
|
||||||
else:
|
|
||||||
y_pred = torch.cat(y_pred, dim=0)
|
|
||||||
y_true = torch.cat(y_true, dim=0)
|
|
||||||
|
|
||||||
# 计算每个时间步的指标
|
|
||||||
for t in range(y_true.shape[1]):
|
|
||||||
mae, rmse, mape = logger.all_metrics(y_pred[:, t, ...], y_true[:, t, ...],
|
|
||||||
args['mae_thresh'], args['mape_thresh'])
|
|
||||||
logger.logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
|
||||||
|
|
||||||
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
|
||||||
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
|
||||||
|
|
||||||
# 保存nfe数据(如果启用)
|
|
||||||
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
|
||||||
Trainer._save_nfe_data_static(c, model, logger)
|
|
||||||
|
|
||||||
# 只在需要时生成可视化图片
|
|
||||||
if generate_viz:
|
|
||||||
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
|
||||||
Trainer._generate_node_visualizations(y_pred, y_true, logger, save_dir)
|
|
||||||
Trainer._generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
|
||||||
target_node=1, num_samples=10, scalers=scalers)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _inverse_transform_output_static(output, args, scalers):
|
|
||||||
"""
|
|
||||||
静态方法:只对输出维度进行反归一化
|
|
||||||
"""
|
|
||||||
if not args['real_value']:
|
|
||||||
return output
|
|
||||||
|
|
||||||
# 获取输出维度的数量
|
|
||||||
output_dim = args['output_dim']
|
|
||||||
|
|
||||||
# 如果输出特征数小于等于标准化器数量,直接使用对应的标准化器
|
|
||||||
if output_dim <= len(scalers):
|
|
||||||
# 对每个输出特征分别进行反归一化
|
|
||||||
for feature_idx in range(output_dim):
|
|
||||||
if feature_idx < len(scalers):
|
|
||||||
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
|
||||||
output[..., feature_idx:feature_idx+1]
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# 如果输出特征数大于标准化器数量,只对前len(scalers)个特征进行反归一化
|
|
||||||
for feature_idx in range(len(scalers)):
|
|
||||||
output[..., feature_idx:feature_idx+1] = scalers[feature_idx].inverse_transform(
|
|
||||||
output[..., feature_idx:feature_idx+1]
|
|
||||||
)
|
|
||||||
|
|
||||||
return output
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _generate_node_visualizations(y_pred, y_true, logger, save_dir):
|
|
||||||
"""
|
|
||||||
生成节点预测可视化图片
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_pred: 预测值
|
|
||||||
y_true: 真实值
|
|
||||||
logger: 日志记录器
|
|
||||||
save_dir: 保存目录
|
|
||||||
"""
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import matplotlib
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# 设置matplotlib配置,减少字体查找输出
|
|
||||||
matplotlib.set_loglevel('error') # 只显示错误信息
|
|
||||||
plt.rcParams['font.family'] = 'DejaVu Sans' # 使用默认字体
|
|
||||||
|
|
||||||
# 检查数据有效性
|
|
||||||
if y_pred is None or y_true is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 创建pic文件夹
|
|
||||||
pic_dir = os.path.join(save_dir, 'pic')
|
|
||||||
os.makedirs(pic_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 固定生成10张图片
|
|
||||||
num_nodes_to_plot = 10
|
|
||||||
|
|
||||||
# 生成单个节点的详细图
|
|
||||||
with tqdm(total=num_nodes_to_plot, desc="Generating node visualizations") as pbar:
|
|
||||||
for node_id in range(num_nodes_to_plot):
|
|
||||||
# 获取对应节点的数据
|
|
||||||
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
|
||||||
# 数据格式: [time_step, seq_len, num_node, dim]
|
|
||||||
node_pred = y_pred[:, 12, node_id, 0].cpu().numpy() # t=1时刻,指定节点,第一个特征
|
|
||||||
node_true = y_true[:, 12, node_id, 0].cpu().numpy()
|
|
||||||
else:
|
|
||||||
# 如果数据不足10个节点,只处理实际存在的节点
|
|
||||||
if node_id >= y_pred.shape[-2]:
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
node_pred = y_pred[:, 0, node_id, 0].cpu().numpy()
|
|
||||||
node_true = y_true[:, 0, node_id, 0].cpu().numpy()
|
|
||||||
|
|
||||||
# 检查数据有效性
|
|
||||||
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 取前500个时间步
|
|
||||||
max_steps = min(500, len(node_pred))
|
|
||||||
if max_steps <= 0:
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
node_pred_500 = node_pred[:max_steps]
|
|
||||||
node_true_500 = node_true[:max_steps]
|
|
||||||
|
|
||||||
# 创建时间轴
|
|
||||||
time_steps = np.arange(max_steps)
|
|
||||||
|
|
||||||
# 绘制对比图
|
|
||||||
plt.figure(figsize=(12, 6))
|
|
||||||
plt.plot(time_steps, node_true_500, 'b-', label='True Values', linewidth=2, alpha=0.8)
|
|
||||||
plt.plot(time_steps, node_pred_500, 'r-', label='Predictions', linewidth=2, alpha=0.8)
|
|
||||||
plt.xlabel('Time Steps')
|
|
||||||
plt.ylabel('Values')
|
|
||||||
plt.title(f'Node {node_id + 1}: True vs Predicted Values (First {max_steps} Time Steps)')
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# 保存图片,使用不同的命名
|
|
||||||
save_path = os.path.join(pic_dir, f'node{node_id + 1:02d}_prediction_first500.png')
|
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
# 生成所有节点的对比图(前100个时间步,便于观察)
|
|
||||||
# 选择前100个时间步
|
|
||||||
plot_steps = min(100, y_pred.shape[0])
|
|
||||||
if plot_steps <= 0:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 创建子图
|
|
||||||
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
|
||||||
axes = axes.flatten()
|
|
||||||
|
|
||||||
for node_id in range(num_nodes_to_plot):
|
|
||||||
if len(y_pred.shape) > 2 and y_pred.shape[-2] > node_id:
|
|
||||||
# 数据格式: [time_step, seq_len, num_node, dim]
|
|
||||||
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
|
||||||
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
|
||||||
else:
|
|
||||||
# 如果数据不足10个节点,只处理实际存在的节点
|
|
||||||
if node_id >= y_pred.shape[-2]:
|
|
||||||
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
|
||||||
ha='center', va='center', transform=axes[node_id].transAxes)
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
node_pred = y_pred[:plot_steps, 0, node_id, 0].cpu().numpy()
|
|
||||||
node_true = y_true[:plot_steps, 0, node_id, 0].cpu().numpy()
|
|
||||||
|
|
||||||
# 检查数据有效性
|
|
||||||
if np.isnan(node_pred).any() or np.isnan(node_true).any():
|
|
||||||
axes[node_id].text(0.5, 0.5, f'Node {node_id + 1}\nNo Data',
|
|
||||||
ha='center', va='center', transform=axes[node_id].transAxes)
|
|
||||||
continue
|
|
||||||
|
|
||||||
time_steps = np.arange(plot_steps)
|
|
||||||
|
|
||||||
axes[node_id].plot(time_steps, node_true, 'b-', label='True', linewidth=1.5, alpha=0.8)
|
|
||||||
axes[node_id].plot(time_steps, node_pred, 'r-', label='Pred', linewidth=1.5, alpha=0.8)
|
|
||||||
axes[node_id].set_title(f'Node {node_id + 1}')
|
|
||||||
axes[node_id].grid(True, alpha=0.3)
|
|
||||||
axes[node_id].legend(fontsize=8)
|
|
||||||
|
|
||||||
if node_id >= 5: # 下面一行添加x轴标签
|
|
||||||
axes[node_id].set_xlabel('Time Steps')
|
|
||||||
if node_id % 5 == 0: # 左边一列添加y轴标签
|
|
||||||
axes[node_id].set_ylabel('Values')
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
summary_path = os.path.join(pic_dir, 'all_nodes_summary.png')
|
|
||||||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _generate_input_output_comparison(y_pred, y_true, data_loader, logger, save_dir,
|
|
||||||
target_node=1, num_samples=10, scalers=None):
|
|
||||||
"""
|
|
||||||
生成输入-输出样本比较图
|
|
||||||
|
|
||||||
Args:
|
|
||||||
y_pred: 预测值
|
|
||||||
y_true: 真实值
|
|
||||||
data_loader: 数据加载器,用于获取输入数据
|
|
||||||
logger: 日志记录器
|
|
||||||
save_dir: 保存目录
|
|
||||||
target_node: 目标节点ID(从1开始)
|
|
||||||
num_samples: 要比较的样本数量
|
|
||||||
scalers: 标准化器列表,用于反归一化输入数据
|
|
||||||
"""
|
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
import numpy as np
|
|
||||||
import os
|
|
||||||
import matplotlib
|
|
||||||
from tqdm import tqdm
|
|
||||||
|
|
||||||
# 设置matplotlib配置
|
|
||||||
matplotlib.set_loglevel('error')
|
|
||||||
plt.rcParams['font.family'] = 'DejaVu Sans'
|
|
||||||
|
|
||||||
# 创建compare文件夹
|
|
||||||
compare_dir = os.path.join(save_dir, 'pic', 'compare')
|
|
||||||
os.makedirs(compare_dir, exist_ok=True)
|
|
||||||
|
|
||||||
# 获取输入数据
|
|
||||||
input_data = []
|
|
||||||
for batch_idx, (data, target) in enumerate(data_loader):
|
|
||||||
if batch_idx >= num_samples:
|
|
||||||
break
|
|
||||||
input_data.append(data.cpu().numpy())
|
|
||||||
|
|
||||||
if not input_data:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 获取目标节点的索引(从0开始)
|
|
||||||
node_idx = target_node - 1
|
|
||||||
|
|
||||||
# 检查节点索引是否有效
|
|
||||||
if node_idx >= y_pred.shape[-2]:
|
|
||||||
return
|
|
||||||
|
|
||||||
# 为每个样本生成比较图
|
|
||||||
with tqdm(total=min(num_samples, len(input_data)), desc="Generating input-output comparisons") as pbar:
|
|
||||||
for sample_idx in range(min(num_samples, len(input_data))):
|
|
||||||
# 获取输入序列(假设输入形状为 [batch, seq_len, nodes, features])
|
|
||||||
input_seq = input_data[sample_idx][0, :, node_idx, 0] # 第一个batch,所有时间步,目标节点,第一个特征
|
|
||||||
|
|
||||||
# 对输入数据进行反归一化
|
|
||||||
if scalers is not None and len(scalers) > 0:
|
|
||||||
# 使用第一个标准化器对输入进行反归一化(假设输入特征使用第一个标准化器)
|
|
||||||
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
|
||||||
|
|
||||||
# 获取对应的预测值和真实值
|
|
||||||
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy() # 所有horizon,目标节点,第一个特征
|
|
||||||
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
|
||||||
|
|
||||||
# 检查数据有效性
|
|
||||||
if (np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any()):
|
|
||||||
pbar.update(1)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 创建时间轴 - 输入和输出连续
|
|
||||||
total_time = np.arange(len(input_seq) + len(pred_seq))
|
|
||||||
|
|
||||||
# 创建合并的图形 - 输入和输出在同一个图中
|
|
||||||
plt.figure(figsize=(14, 8))
|
|
||||||
|
|
||||||
# 绘制完整的真实值曲线(输入 + 真实输出)
|
|
||||||
true_combined = np.concatenate([input_seq, true_seq])
|
|
||||||
plt.plot(total_time, true_combined, 'b', label='True Values (Input + Output)',
|
|
||||||
linewidth=2.5, alpha=0.9, linestyle='-')
|
|
||||||
|
|
||||||
# 绘制预测值曲线(只绘制输出部分)
|
|
||||||
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
|
||||||
plt.plot(output_time, pred_seq, 'r', label='Predicted Values',
|
|
||||||
linewidth=2, alpha=0.8, linestyle='-')
|
|
||||||
|
|
||||||
# 添加垂直线分隔输入和输出
|
|
||||||
plt.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.7,
|
|
||||||
label='Input/Output Boundary')
|
|
||||||
|
|
||||||
# 设置图形属性
|
|
||||||
plt.xlabel('Time Steps')
|
|
||||||
plt.ylabel('Values')
|
|
||||||
plt.title(f'Sample {sample_idx + 1}: Input-Output Comparison (Node {target_node})')
|
|
||||||
plt.legend()
|
|
||||||
plt.grid(True, alpha=0.3)
|
|
||||||
|
|
||||||
# 调整布局
|
|
||||||
plt.tight_layout()
|
|
||||||
|
|
||||||
# 保存图片
|
|
||||||
save_path = os.path.join(compare_dir, f'sample{sample_idx + 1:02d}_node{target_node:02d}_comparison.png')
|
|
||||||
plt.savefig(save_path, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
pbar.update(1)
|
|
||||||
|
|
||||||
# 生成汇总图(所有样本的预测值对比)
|
|
||||||
|
|
||||||
fig, axes = plt.subplots(2, 5, figsize=(20, 8))
|
|
||||||
axes = axes.flatten()
|
|
||||||
|
|
||||||
for sample_idx in range(min(num_samples, len(input_data))):
|
|
||||||
if sample_idx >= 10: # 最多显示10个子图
|
|
||||||
break
|
|
||||||
|
|
||||||
ax = axes[sample_idx]
|
|
||||||
|
|
||||||
# 获取输入序列和预测值、真实值
|
|
||||||
input_seq = input_data[sample_idx][0, :, node_idx, 0]
|
|
||||||
if scalers is not None and len(scalers) > 0:
|
|
||||||
input_seq = scalers[0].inverse_transform(input_seq.reshape(-1, 1)).flatten()
|
|
||||||
|
|
||||||
pred_seq = y_pred[sample_idx, :, node_idx, 0].cpu().numpy()
|
|
||||||
true_seq = y_true[sample_idx, :, node_idx, 0].cpu().numpy()
|
|
||||||
|
|
||||||
# 检查数据有效性
|
|
||||||
if np.isnan(input_seq).any() or np.isnan(pred_seq).any() or np.isnan(true_seq).any():
|
|
||||||
ax.text(0.5, 0.5, f'Sample {sample_idx + 1}\nNo Data',
|
|
||||||
ha='center', va='center', transform=ax.transAxes)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# 绘制对比图 - 输入和输出连续显示
|
|
||||||
total_time = np.arange(len(input_seq) + len(pred_seq))
|
|
||||||
true_combined = np.concatenate([input_seq, true_seq])
|
|
||||||
output_time = np.arange(len(input_seq), len(input_seq) + len(pred_seq))
|
|
||||||
|
|
||||||
ax.plot(total_time, true_combined, 'b', label='True', linewidth=2, alpha=0.9, linestyle='-')
|
|
||||||
ax.plot(output_time, pred_seq, 'r', label='Pred', linewidth=1.5, alpha=0.8, linestyle='-')
|
|
||||||
ax.axvline(x=len(input_seq)-0.5, color='gray', linestyle=':', alpha=0.5)
|
|
||||||
ax.set_title(f'Sample {sample_idx + 1}')
|
|
||||||
ax.grid(True, alpha=0.3)
|
|
||||||
ax.legend(fontsize=8)
|
|
||||||
|
|
||||||
if sample_idx >= 5: # 下面一行添加x轴标签
|
|
||||||
ax.set_xlabel('Time Steps')
|
|
||||||
if sample_idx % 5 == 0: # 左边一列添加y轴标签
|
|
||||||
ax.set_ylabel('Values')
|
|
||||||
|
|
||||||
# 隐藏多余的子图
|
|
||||||
for i in range(min(num_samples, len(input_data)), 10):
|
|
||||||
axes[i].set_visible(False)
|
|
||||||
|
|
||||||
plt.tight_layout()
|
|
||||||
summary_path = os.path.join(compare_dir, f'all_samples_node{target_node:02d}_summary.png')
|
|
||||||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
|
||||||
plt.close()
|
|
||||||
|
|
||||||
def _save_nfe_data(self):
|
|
||||||
"""保存nfe数据到文件"""
|
|
||||||
if not self.res:
|
|
||||||
return
|
|
||||||
|
|
||||||
res = pd.concat(self.res, keys=self.keys)
|
|
||||||
res.index.names = ['epoch', 'iter']
|
|
||||||
|
|
||||||
# 获取模型配置参数
|
|
||||||
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
|
||||||
atol = getattr(self.model, 'atol', 1e-5)
|
|
||||||
rtol = getattr(self.model, 'rtol', 1e-5)
|
|
||||||
|
|
||||||
# 保存nfe数据
|
|
||||||
nfe_file = os.path.join(
|
|
||||||
self.logger.dir_path,
|
|
||||||
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
|
||||||
res.to_pickle(nfe_file)
|
|
||||||
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _compute_sampling_threshold(global_step, k):
|
|
||||||
return k / (k + math.exp(global_step / k))
|
|
||||||
|
|
@ -1,14 +1,11 @@
|
||||||
from trainer.trainer import Trainer
|
from trainer.trainer import Trainer
|
||||||
from trainer.ode_trainer import Trainer as ode_trainer
|
|
||||||
|
|
||||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||||
lr_scheduler, kwargs):
|
lr_scheduler, kwargs):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
selected_Trainer = None
|
selected_Trainer = None
|
||||||
match model_name:
|
match model_name:
|
||||||
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
|
||||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
|
||||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
||||||
if selected_Trainer is None: raise NotImplementedError
|
if selected_Trainer is None: raise NotImplementedError
|
||||||
return selected_Trainer
|
return selected_Trainer
|
||||||
Loading…
Reference in New Issue