更新项目文件:添加数据选择器、模型选择器、训练器选择器和ODE训练器
This commit is contained in:
parent
df8c573f4c
commit
66a23ffbbb
|
|
@ -160,3 +160,4 @@ cython_debug/
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
|
STDEN/
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
# 默认忽略的文件
|
||||||
|
/shelf/
|
||||||
|
/workspace.xml
|
||||||
|
# 基于编辑器的 HTTP 客户端请求
|
||||||
|
/httpRequests/
|
||||||
|
# Datasource local storage ignored files
|
||||||
|
/dataSources/
|
||||||
|
/dataSources.local.xml
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<module type="PYTHON_MODULE" version="4">
|
||||||
|
<component name="NewModuleRootManager">
|
||||||
|
<content url="file://$MODULE_DIR$" />
|
||||||
|
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||||
|
<orderEntry type="sourceFolder" forTests="false" />
|
||||||
|
</component>
|
||||||
|
<component name="PyDocumentationSettings">
|
||||||
|
<option name="format" value="PLAIN" />
|
||||||
|
<option name="myDocStringFormat" value="Plain" />
|
||||||
|
</component>
|
||||||
|
</module>
|
||||||
|
|
@ -0,0 +1,27 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<profile version="1.0">
|
||||||
|
<option name="myName" value="Project Default" />
|
||||||
|
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||||
|
<Languages>
|
||||||
|
<language minSize="136" name="Python" />
|
||||||
|
</Languages>
|
||||||
|
</inspection_tool>
|
||||||
|
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||||
|
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||||
|
<option name="ignoredPackages">
|
||||||
|
<value>
|
||||||
|
<list size="8">
|
||||||
|
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
||||||
|
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
||||||
|
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
||||||
|
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
||||||
|
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
||||||
|
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
||||||
|
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||||
|
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
||||||
|
</list>
|
||||||
|
</value>
|
||||||
|
</option>
|
||||||
|
</inspection_tool>
|
||||||
|
</profile>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,6 @@
|
||||||
|
<component name="InspectionProjectProfileManager">
|
||||||
|
<settings>
|
||||||
|
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||||
|
<version value="1.0" />
|
||||||
|
</settings>
|
||||||
|
</component>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="Black">
|
||||||
|
<option name="sdkName" value="Python 3.10" />
|
||||||
|
</component>
|
||||||
|
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="ProjectModuleManager">
|
||||||
|
<modules>
|
||||||
|
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
||||||
|
</modules>
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
<?xml version="1.0" encoding="UTF-8"?>
|
||||||
|
<project version="4">
|
||||||
|
<component name="VcsDirectoryMappings">
|
||||||
|
<mapping directory="" vcs="Git" />
|
||||||
|
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||||
|
</component>
|
||||||
|
</project>
|
||||||
|
|
@ -0,0 +1,296 @@
|
||||||
|
from,to,cost
|
||||||
|
9,153,310.6
|
||||||
|
153,62,330.9
|
||||||
|
62,111,332.9
|
||||||
|
111,11,324.2
|
||||||
|
11,28,336.0
|
||||||
|
28,169,133.7
|
||||||
|
138,135,354.7
|
||||||
|
135,133,387.9
|
||||||
|
133,163,337.1
|
||||||
|
163,20,352.0
|
||||||
|
20,19,420.8
|
||||||
|
19,14,351.3
|
||||||
|
14,39,340.2
|
||||||
|
39,164,350.3
|
||||||
|
164,167,365.2
|
||||||
|
167,70,359.0
|
||||||
|
70,59,388.2
|
||||||
|
59,58,305.7
|
||||||
|
58,67,294.4
|
||||||
|
67,66,299.5
|
||||||
|
66,55,313.3
|
||||||
|
55,53,332.1
|
||||||
|
53,150,278.9
|
||||||
|
150,61,308.4
|
||||||
|
61,64,311.4
|
||||||
|
64,63,243.6
|
||||||
|
47,65,372.8
|
||||||
|
65,48,319.4
|
||||||
|
48,49,309.7
|
||||||
|
49,54,320.5
|
||||||
|
54,56,318.3
|
||||||
|
56,57,297.9
|
||||||
|
57,68,293.5
|
||||||
|
68,69,342.5
|
||||||
|
69,60,318.0
|
||||||
|
60,17,305.9
|
||||||
|
17,5,321.4
|
||||||
|
5,18,402.2
|
||||||
|
18,22,447.4
|
||||||
|
22,30,377.5
|
||||||
|
30,29,417.7
|
||||||
|
29,21,360.8
|
||||||
|
21,132,407.6
|
||||||
|
132,134,386.9
|
||||||
|
134,136,350.2
|
||||||
|
123,121,326.3
|
||||||
|
121,140,385.2
|
||||||
|
140,118,393.0
|
||||||
|
118,96,296.7
|
||||||
|
96,94,398.2
|
||||||
|
94,86,337.1
|
||||||
|
86,78,473.8
|
||||||
|
78,46,353.4
|
||||||
|
46,152,385.7
|
||||||
|
152,157,350.0
|
||||||
|
157,35,354.4
|
||||||
|
35,77,356.1
|
||||||
|
77,52,354.2
|
||||||
|
52,3,357.8
|
||||||
|
3,16,382.4
|
||||||
|
16,0,55.7
|
||||||
|
42,12,335.1
|
||||||
|
12,139,328.8
|
||||||
|
139,168,412.6
|
||||||
|
168,154,337.3
|
||||||
|
154,143,370.7
|
||||||
|
143,10,6.3
|
||||||
|
107,105,354.6
|
||||||
|
105,104,386.9
|
||||||
|
104,148,362.1
|
||||||
|
148,97,316.3
|
||||||
|
97,101,380.7
|
||||||
|
101,137,361.4
|
||||||
|
137,102,365.5
|
||||||
|
102,24,375.5
|
||||||
|
24,166,312.2
|
||||||
|
129,156,256.1
|
||||||
|
156,33,329.1
|
||||||
|
33,32,356.5
|
||||||
|
91,89,405.6
|
||||||
|
89,147,347.0
|
||||||
|
147,15,351.7
|
||||||
|
15,44,339.5
|
||||||
|
44,41,350.8
|
||||||
|
41,43,322.6
|
||||||
|
43,100,338.9
|
||||||
|
100,83,347.9
|
||||||
|
83,87,327.2
|
||||||
|
87,88,321.0
|
||||||
|
88,75,335.8
|
||||||
|
75,51,384.8
|
||||||
|
51,73,391.1
|
||||||
|
73,71,289.3
|
||||||
|
31,155,260.0
|
||||||
|
155,34,320.4
|
||||||
|
34,128,393.3
|
||||||
|
145,115,399.4
|
||||||
|
115,112,328.1
|
||||||
|
112,8,469.4
|
||||||
|
8,117,816.2
|
||||||
|
117,125,397.1
|
||||||
|
125,127,372.7
|
||||||
|
127,109,380.5
|
||||||
|
109,161,355.5
|
||||||
|
161,110,367.7
|
||||||
|
110,160,102.0
|
||||||
|
72,159,342.9
|
||||||
|
159,50,383.3
|
||||||
|
50,74,354.1
|
||||||
|
74,82,350.2
|
||||||
|
82,81,335.4
|
||||||
|
81,99,391.6
|
||||||
|
99,84,354.9
|
||||||
|
84,13,306.4
|
||||||
|
13,40,327.4
|
||||||
|
40,162,413.9
|
||||||
|
162,108,301.9
|
||||||
|
108,146,317.8
|
||||||
|
146,85,376.6
|
||||||
|
85,90,347.0
|
||||||
|
26,27,341.6
|
||||||
|
27,6,359.4
|
||||||
|
6,149,417.8
|
||||||
|
149,126,388.0
|
||||||
|
126,124,384.3
|
||||||
|
124,7,763.3
|
||||||
|
7,114,323.1
|
||||||
|
114,113,351.6
|
||||||
|
113,116,411.9
|
||||||
|
116,144,262.0
|
||||||
|
25,103,350.2
|
||||||
|
103,23,376.3
|
||||||
|
23,165,396.4
|
||||||
|
165,38,381.0
|
||||||
|
38,92,368.0
|
||||||
|
92,37,336.3
|
||||||
|
37,130,357.8
|
||||||
|
130,106,532.3
|
||||||
|
106,131,166.5
|
||||||
|
1,2,371.6
|
||||||
|
2,4,338.1
|
||||||
|
4,76,429.0
|
||||||
|
76,36,366.1
|
||||||
|
36,158,344.5
|
||||||
|
158,151,350.1
|
||||||
|
151,45,358.8
|
||||||
|
45,93,340.9
|
||||||
|
93,80,329.9
|
||||||
|
80,79,384.1
|
||||||
|
79,95,335.7
|
||||||
|
95,98,320.9
|
||||||
|
98,119,340.3
|
||||||
|
119,120,376.8
|
||||||
|
120,122,393.1
|
||||||
|
122,141,428.7
|
||||||
|
141,142,359.3
|
||||||
|
30,165,379.6
|
||||||
|
165,29,41.7
|
||||||
|
29,38,343.3
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
17,153,375.6
|
||||||
|
153,5,256.3
|
||||||
|
153,62,330.9
|
||||||
|
18,6,499.4
|
||||||
|
6,22,254.0
|
||||||
|
22,149,185.4
|
||||||
|
22,4,257.9
|
||||||
|
4,30,236.8
|
||||||
|
30,76,307.0
|
||||||
|
95,98,320.9
|
||||||
|
98,144,45.1
|
||||||
|
45,93,340.9
|
||||||
|
93,106,112.2
|
||||||
|
162,151,113.6
|
||||||
|
151,108,192.9
|
||||||
|
108,45,359.8
|
||||||
|
146,92,311.2
|
||||||
|
92,85,343.9
|
||||||
|
85,37,373.2
|
||||||
|
13,169,326.2
|
||||||
|
169,40,96.1
|
||||||
|
124,13,460.7
|
||||||
|
13,7,305.5
|
||||||
|
7,40,624.1
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
90,132,152.2
|
||||||
|
26,32,106.7
|
||||||
|
9,129,148.3
|
||||||
|
129,153,219.6
|
||||||
|
31,26,116.0
|
||||||
|
26,155,270.7
|
||||||
|
9,128,142.2
|
||||||
|
128,153,215.0
|
||||||
|
153,167,269.7
|
||||||
|
167,62,64.8
|
||||||
|
62,70,332.6
|
||||||
|
124,169,145.2
|
||||||
|
169,7,631.5
|
||||||
|
44,169,397.8
|
||||||
|
169,41,124.0
|
||||||
|
44,124,375.7
|
||||||
|
124,41,243.9
|
||||||
|
41,7,519.4
|
||||||
|
6,14,289.3
|
||||||
|
14,149,259.0
|
||||||
|
149,39,206.9
|
||||||
|
144,98,45.1
|
||||||
|
19,4,326.8
|
||||||
|
4,14,178.6
|
||||||
|
14,76,299.0
|
||||||
|
15,151,136.4
|
||||||
|
151,44,203.1
|
||||||
|
45,106,260.6
|
||||||
|
106,93,112.2
|
||||||
|
20,165,132.5
|
||||||
|
165,19,289.2
|
||||||
|
89,92,323.2
|
||||||
|
92,147,321.9
|
||||||
|
147,37,48.2
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
104,147,277.5
|
||||||
|
147,148,84.7
|
||||||
|
20,101,201.2
|
||||||
|
101,19,534.4
|
||||||
|
19,137,245.5
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
44,42,342.3
|
||||||
|
42,41,102.5
|
||||||
|
44,8,789.1
|
||||||
|
8,41,657.4
|
||||||
|
41,117,160.5
|
||||||
|
168,167,172.4
|
||||||
|
167,154,165.2
|
||||||
|
143,128,81.9
|
||||||
|
128,10,88.2
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
15,152,135.0
|
||||||
|
152,44,204.6
|
||||||
|
19,77,320.7
|
||||||
|
77,14,299.8
|
||||||
|
14,52,127.6
|
||||||
|
14,127,314.8
|
||||||
|
127,39,280.4
|
||||||
|
39,109,237.0
|
||||||
|
31,160,116.5
|
||||||
|
160,155,272.4
|
||||||
|
133,91,152.8
|
||||||
|
91,163,313.6
|
||||||
|
150,71,221.1
|
||||||
|
71,61,89.6
|
||||||
|
32,160,107.7
|
||||||
|
72,162,3274.4
|
||||||
|
162,13,554.5
|
||||||
|
162,40,413.9
|
||||||
|
65,72,297.9
|
||||||
|
72,48,21.5
|
||||||
|
13,42,319.8
|
||||||
|
42,40,40.7
|
||||||
|
8,42,759.5
|
||||||
|
42,117,58.9
|
||||||
|
8,13,450.3
|
||||||
|
13,117,378.5
|
||||||
|
117,40,64.0
|
||||||
|
46,162,391.6
|
||||||
|
162,152,115.3
|
||||||
|
152,108,191.4
|
||||||
|
104,108,375.9
|
||||||
|
108,148,311.6
|
||||||
|
148,146,80.0
|
||||||
|
21,90,396.9
|
||||||
|
90,132,152.2
|
||||||
|
101,29,252.3
|
||||||
|
29,137,110.7
|
||||||
|
77,22,353.8
|
||||||
|
22,52,227.8
|
||||||
|
52,30,186.6
|
||||||
|
127,18,425.2
|
||||||
|
18,109,439.1
|
||||||
|
109,22,135.5
|
||||||
|
168,17,232.7
|
||||||
|
17,154,294.2
|
||||||
|
154,5,166.3
|
||||||
|
78,107,143.9
|
||||||
|
107,46,236.3
|
||||||
|
118,145,250.6
|
||||||
|
145,96,85.1
|
||||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,6 +2,7 @@ import numpy as np
|
||||||
import os
|
import os
|
||||||
|
|
||||||
def load_dataset(config):
|
def load_dataset(config):
|
||||||
|
|
||||||
dataset_name = config['basic']['dataset']
|
dataset_name = config['basic']['dataset']
|
||||||
node_num = config['data']['num_nodes']
|
node_num = config['data']['num_nodes']
|
||||||
input_dim = config['data']['input_dim']
|
input_dim = config['data']['input_dim']
|
||||||
|
|
@ -10,4 +11,8 @@ def load_dataset(config):
|
||||||
case 'EcoSolar':
|
case 'EcoSolar':
|
||||||
data_path = os.path.join('./data/EcoSolar.npy')
|
data_path = os.path.join('./data/EcoSolar.npy')
|
||||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||||
|
case 'PEMS08':
|
||||||
|
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||||
|
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
@ -0,0 +1,9 @@
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
def load_graph(config):
|
||||||
|
dataset_path = config['data']['graph_pkl_filename']
|
||||||
|
graph = np.load(dataset_path)
|
||||||
|
# 将inf值填充为0
|
||||||
|
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
||||||
|
|
||||||
|
return graph
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 100
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 100
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
|
|
@ -0,0 +1,62 @@
|
||||||
|
basic:
|
||||||
|
dataset: PEMS08
|
||||||
|
device: cuda:0
|
||||||
|
mode: train
|
||||||
|
model: STDEN
|
||||||
|
seed: 2025
|
||||||
|
data:
|
||||||
|
add_day_in_week: true
|
||||||
|
add_time_in_day: true
|
||||||
|
batch_size: 32
|
||||||
|
column_wise: false
|
||||||
|
dataset_dir: data/PEMS08
|
||||||
|
days_per_week: 7
|
||||||
|
default_graph: true
|
||||||
|
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||||
|
horizon: 24
|
||||||
|
input_dim: 1
|
||||||
|
lag: 24
|
||||||
|
normalizer: std
|
||||||
|
num_nodes: 170
|
||||||
|
steps_per_day: 24
|
||||||
|
test_ratio: 0.2
|
||||||
|
tod: false
|
||||||
|
val_batch_size: 32
|
||||||
|
val_ratio: 0.2
|
||||||
|
model:
|
||||||
|
filter_type: default
|
||||||
|
gcn_step: 2
|
||||||
|
horizon: 12
|
||||||
|
input_dim: 1
|
||||||
|
l1_decay: 0
|
||||||
|
latent_dim: 4
|
||||||
|
n_traj_samples: 3
|
||||||
|
nfe: false
|
||||||
|
num_rnn_layers: 1
|
||||||
|
ode_method: dopri5
|
||||||
|
odeint_atol: 1.0e-05
|
||||||
|
odeint_rtol: 1.0e-05
|
||||||
|
output_dim: 1
|
||||||
|
recg_type: gru
|
||||||
|
rnn_units: 64
|
||||||
|
save_latent: false
|
||||||
|
seq_len: 12
|
||||||
|
train:
|
||||||
|
batch_size: 64
|
||||||
|
debug: false
|
||||||
|
early_stop: true
|
||||||
|
early_stop_patience: 15
|
||||||
|
epochs: 100
|
||||||
|
grad_norm: false
|
||||||
|
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
|
||||||
|
loss: mae
|
||||||
|
lr_decay: false
|
||||||
|
lr_decay_rate: 0.3
|
||||||
|
lr_decay_step: 5,20,40,70
|
||||||
|
lr_init: 0.003
|
||||||
|
mae_thresh: None
|
||||||
|
mape_thresh: 0.001
|
||||||
|
max_grad_norm: 5
|
||||||
|
output_dim: 1
|
||||||
|
real_value: true
|
||||||
|
weight_decay: 0
|
||||||
2
main.py
2
main.py
|
|
@ -3,8 +3,6 @@
|
||||||
时空数据深度学习预测项目主程序
|
时空数据深度学习预测项目主程序
|
||||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from utils.args_reader import config_loader
|
from utils.args_reader import config_loader
|
||||||
import utils.init as init
|
import utils.init as init
|
||||||
import torch
|
import torch
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
from models.STDEN.stden_model import STDENModel
|
||||||
|
|
||||||
def model_selector(config):
|
def model_selector(config):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
model = None
|
model = None
|
||||||
|
match model_name:
|
||||||
|
case 'STDEN': model = STDENModel(config)
|
||||||
return model
|
return model
|
||||||
|
|
@ -2,6 +2,8 @@ import math
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
import copy
|
import copy
|
||||||
|
import pandas as pd
|
||||||
|
import numpy as np
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
@ -26,6 +28,10 @@ class Trainer:
|
||||||
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
||||||
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
||||||
|
|
||||||
|
# 用于收集nfe数据
|
||||||
|
self.c = []
|
||||||
|
self.res, self.keys = [], []
|
||||||
|
|
||||||
def _run_epoch(self, epoch, dataloader, mode):
|
def _run_epoch(self, epoch, dataloader, mode):
|
||||||
if mode == 'train':
|
if mode == 'train':
|
||||||
self.model.train()
|
self.model.train()
|
||||||
|
|
@ -37,17 +43,28 @@ class Trainer:
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
epoch_time = time.time()
|
epoch_time = time.time()
|
||||||
|
|
||||||
|
# 清空nfe数据收集
|
||||||
|
if mode == 'train':
|
||||||
|
self.c.clear()
|
||||||
|
|
||||||
with torch.set_grad_enabled(optimizer_step):
|
with torch.set_grad_enabled(optimizer_step):
|
||||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||||
for batch_idx, (data, target) in enumerate(dataloader):
|
for batch_idx, (data, target) in enumerate(dataloader):
|
||||||
label = target[..., :self.args['output_dim']]
|
label = target[..., :self.args['output_dim']]
|
||||||
output = self.model(data).to(self.args['device'])
|
output, fe = self.model(data)
|
||||||
|
|
||||||
if self.args['real_value']:
|
if self.args['real_value']:
|
||||||
# 只对输出维度进行反归一化
|
# 只对输出维度进行反归一化
|
||||||
output = self._inverse_transform_output(output)
|
output = self._inverse_transform_output(output)
|
||||||
|
|
||||||
loss = self.loss(output, label)
|
loss = self.loss(output, label)
|
||||||
|
|
||||||
|
# 收集nfe数据(仅在训练模式下)
|
||||||
|
if mode == 'train':
|
||||||
|
self.c.append([*fe, loss.item()])
|
||||||
|
# 记录FE信息
|
||||||
|
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
||||||
|
|
||||||
if optimizer_step and self.optimizer is not None:
|
if optimizer_step and self.optimizer is not None:
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
@ -69,6 +86,12 @@ class Trainer:
|
||||||
avg_loss = total_loss / len(dataloader)
|
avg_loss = total_loss / len(dataloader)
|
||||||
self.logger.logger.info(
|
self.logger.logger.info(
|
||||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||||
|
|
||||||
|
# 收集nfe数据(仅在训练模式下)
|
||||||
|
if mode == 'train':
|
||||||
|
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
|
||||||
|
self.keys.append(epoch)
|
||||||
|
|
||||||
return avg_loss
|
return avg_loss
|
||||||
|
|
||||||
def _inverse_transform_output(self, output):
|
def _inverse_transform_output(self, output):
|
||||||
|
|
@ -143,6 +166,10 @@ class Trainer:
|
||||||
best_test_model = copy.deepcopy(self.model.state_dict())
|
best_test_model = copy.deepcopy(self.model.state_dict())
|
||||||
torch.save(best_test_model, self.best_test_path)
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
||||||
|
self._save_nfe_data()
|
||||||
|
|
||||||
if not self.args['debug']:
|
if not self.args['debug']:
|
||||||
torch.save(best_model, self.best_path)
|
torch.save(best_model, self.best_path)
|
||||||
torch.save(best_test_model, self.best_test_path)
|
torch.save(best_test_model, self.best_test_path)
|
||||||
|
|
@ -164,18 +191,24 @@ class Trainer:
|
||||||
if path:
|
if path:
|
||||||
checkpoint = torch.load(path)
|
checkpoint = torch.load(path)
|
||||||
model.load_state_dict(checkpoint['state_dict'])
|
model.load_state_dict(checkpoint['state_dict'])
|
||||||
model.to(args.device)
|
model.to(args['device'])
|
||||||
|
|
||||||
model.eval()
|
model.eval()
|
||||||
y_pred, y_true = [], []
|
y_pred, y_true = [], []
|
||||||
|
|
||||||
|
# 用于收集nfe数据
|
||||||
|
c = []
|
||||||
|
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for data, target in data_loader:
|
for data, target in data_loader:
|
||||||
label = target[..., :args['output_dim']]
|
label = target[..., :args['output_dim']]
|
||||||
output = model(data)
|
output, fe = model(data)
|
||||||
y_pred.append(output)
|
y_pred.append(output)
|
||||||
y_true.append(label)
|
y_true.append(label)
|
||||||
|
|
||||||
|
# 收集nfe数据
|
||||||
|
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
||||||
|
|
||||||
if args['real_value']:
|
if args['real_value']:
|
||||||
# 只对输出维度进行反归一化
|
# 只对输出维度进行反归一化
|
||||||
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
||||||
|
|
@ -192,6 +225,10 @@ class Trainer:
|
||||||
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
||||||
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||||
|
|
||||||
|
# 保存nfe数据(如果启用)
|
||||||
|
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
||||||
|
Trainer._save_nfe_data_static(c, model, logger)
|
||||||
|
|
||||||
# 只在需要时生成可视化图片
|
# 只在需要时生成可视化图片
|
||||||
if generate_viz:
|
if generate_viz:
|
||||||
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
||||||
|
|
@ -514,6 +551,26 @@ class Trainer:
|
||||||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||||
plt.close()
|
plt.close()
|
||||||
|
|
||||||
|
def _save_nfe_data(self):
|
||||||
|
"""保存nfe数据到文件"""
|
||||||
|
if not self.res:
|
||||||
|
return
|
||||||
|
|
||||||
|
res = pd.concat(self.res, keys=self.keys)
|
||||||
|
res.index.names = ['epoch', 'iter']
|
||||||
|
|
||||||
|
# 获取模型配置参数
|
||||||
|
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
||||||
|
atol = getattr(self.model, 'atol', 1e-5)
|
||||||
|
rtol = getattr(self.model, 'rtol', 1e-5)
|
||||||
|
|
||||||
|
# 保存nfe数据
|
||||||
|
nfe_file = os.path.join(
|
||||||
|
self.logger.dir_path,
|
||||||
|
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||||
|
res.to_pickle(nfe_file)
|
||||||
|
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _compute_sampling_threshold(global_step, k):
|
def _compute_sampling_threshold(global_step, k):
|
||||||
return k / (k + math.exp(global_step / k))
|
return k / (k + math.exp(global_step / k))
|
||||||
|
|
|
||||||
|
|
@ -1,11 +1,14 @@
|
||||||
from trainer.trainer import Trainer
|
from trainer.trainer import Trainer
|
||||||
|
from trainer.ode_trainer import Trainer as ode_trainer
|
||||||
|
|
||||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||||
lr_scheduler, kwargs):
|
lr_scheduler, kwargs):
|
||||||
model_name = config['basic']['model']
|
model_name = config['basic']['model']
|
||||||
selected_Trainer = None
|
selected_Trainer = None
|
||||||
match model_name:
|
match model_name:
|
||||||
|
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
||||||
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||||
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||||
if selected_Trainer is None: raise NotImplementedError
|
if selected_Trainer is None: raise NotImplementedError
|
||||||
return selected_Trainer
|
return selected_Trainer
|
||||||
Loading…
Reference in New Issue