更新项目文件:添加数据选择器、模型选择器、训练器选择器和ODE训练器
This commit is contained in:
parent
df8c573f4c
commit
66a23ffbbb
|
|
@ -160,3 +160,4 @@ cython_debug/
|
|||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
STDEN/
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
# 默认忽略的文件
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# 基于编辑器的 HTTP 客户端请求
|
||||
/httpRequests/
|
||||
# Datasource local storage ignored files
|
||||
/dataSources/
|
||||
/dataSources.local.xml
|
||||
|
|
@ -0,0 +1,12 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="TS" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
|
|
@ -0,0 +1,27 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<profile version="1.0">
|
||||
<option name="myName" value="Project Default" />
|
||||
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
||||
<Languages>
|
||||
<language minSize="136" name="Python" />
|
||||
</Languages>
|
||||
</inspection_tool>
|
||||
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
||||
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
||||
<option name="ignoredPackages">
|
||||
<value>
|
||||
<list size="8">
|
||||
<item index="0" class="java.lang.String" itemvalue="argparse" />
|
||||
<item index="1" class="java.lang.String" itemvalue="torch_summary" />
|
||||
<item index="2" class="java.lang.String" itemvalue="positional_encodings" />
|
||||
<item index="3" class="java.lang.String" itemvalue="scikit_learn" />
|
||||
<item index="4" class="java.lang.String" itemvalue="easy_torch" />
|
||||
<item index="5" class="java.lang.String" itemvalue="setuptools" />
|
||||
<item index="6" class="java.lang.String" itemvalue="numpy" />
|
||||
<item index="7" class="java.lang.String" itemvalue="openpyxl" />
|
||||
</list>
|
||||
</value>
|
||||
</option>
|
||||
</inspection_tool>
|
||||
</profile>
|
||||
</component>
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.10" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="TS" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/Project-I.iml" filepath="$PROJECT_DIR$/.idea/Project-I.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="" vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/STDEN" vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
|
|
@ -0,0 +1,296 @@
|
|||
from,to,cost
|
||||
9,153,310.6
|
||||
153,62,330.9
|
||||
62,111,332.9
|
||||
111,11,324.2
|
||||
11,28,336.0
|
||||
28,169,133.7
|
||||
138,135,354.7
|
||||
135,133,387.9
|
||||
133,163,337.1
|
||||
163,20,352.0
|
||||
20,19,420.8
|
||||
19,14,351.3
|
||||
14,39,340.2
|
||||
39,164,350.3
|
||||
164,167,365.2
|
||||
167,70,359.0
|
||||
70,59,388.2
|
||||
59,58,305.7
|
||||
58,67,294.4
|
||||
67,66,299.5
|
||||
66,55,313.3
|
||||
55,53,332.1
|
||||
53,150,278.9
|
||||
150,61,308.4
|
||||
61,64,311.4
|
||||
64,63,243.6
|
||||
47,65,372.8
|
||||
65,48,319.4
|
||||
48,49,309.7
|
||||
49,54,320.5
|
||||
54,56,318.3
|
||||
56,57,297.9
|
||||
57,68,293.5
|
||||
68,69,342.5
|
||||
69,60,318.0
|
||||
60,17,305.9
|
||||
17,5,321.4
|
||||
5,18,402.2
|
||||
18,22,447.4
|
||||
22,30,377.5
|
||||
30,29,417.7
|
||||
29,21,360.8
|
||||
21,132,407.6
|
||||
132,134,386.9
|
||||
134,136,350.2
|
||||
123,121,326.3
|
||||
121,140,385.2
|
||||
140,118,393.0
|
||||
118,96,296.7
|
||||
96,94,398.2
|
||||
94,86,337.1
|
||||
86,78,473.8
|
||||
78,46,353.4
|
||||
46,152,385.7
|
||||
152,157,350.0
|
||||
157,35,354.4
|
||||
35,77,356.1
|
||||
77,52,354.2
|
||||
52,3,357.8
|
||||
3,16,382.4
|
||||
16,0,55.7
|
||||
42,12,335.1
|
||||
12,139,328.8
|
||||
139,168,412.6
|
||||
168,154,337.3
|
||||
154,143,370.7
|
||||
143,10,6.3
|
||||
107,105,354.6
|
||||
105,104,386.9
|
||||
104,148,362.1
|
||||
148,97,316.3
|
||||
97,101,380.7
|
||||
101,137,361.4
|
||||
137,102,365.5
|
||||
102,24,375.5
|
||||
24,166,312.2
|
||||
129,156,256.1
|
||||
156,33,329.1
|
||||
33,32,356.5
|
||||
91,89,405.6
|
||||
89,147,347.0
|
||||
147,15,351.7
|
||||
15,44,339.5
|
||||
44,41,350.8
|
||||
41,43,322.6
|
||||
43,100,338.9
|
||||
100,83,347.9
|
||||
83,87,327.2
|
||||
87,88,321.0
|
||||
88,75,335.8
|
||||
75,51,384.8
|
||||
51,73,391.1
|
||||
73,71,289.3
|
||||
31,155,260.0
|
||||
155,34,320.4
|
||||
34,128,393.3
|
||||
145,115,399.4
|
||||
115,112,328.1
|
||||
112,8,469.4
|
||||
8,117,816.2
|
||||
117,125,397.1
|
||||
125,127,372.7
|
||||
127,109,380.5
|
||||
109,161,355.5
|
||||
161,110,367.7
|
||||
110,160,102.0
|
||||
72,159,342.9
|
||||
159,50,383.3
|
||||
50,74,354.1
|
||||
74,82,350.2
|
||||
82,81,335.4
|
||||
81,99,391.6
|
||||
99,84,354.9
|
||||
84,13,306.4
|
||||
13,40,327.4
|
||||
40,162,413.9
|
||||
162,108,301.9
|
||||
108,146,317.8
|
||||
146,85,376.6
|
||||
85,90,347.0
|
||||
26,27,341.6
|
||||
27,6,359.4
|
||||
6,149,417.8
|
||||
149,126,388.0
|
||||
126,124,384.3
|
||||
124,7,763.3
|
||||
7,114,323.1
|
||||
114,113,351.6
|
||||
113,116,411.9
|
||||
116,144,262.0
|
||||
25,103,350.2
|
||||
103,23,376.3
|
||||
23,165,396.4
|
||||
165,38,381.0
|
||||
38,92,368.0
|
||||
92,37,336.3
|
||||
37,130,357.8
|
||||
130,106,532.3
|
||||
106,131,166.5
|
||||
1,2,371.6
|
||||
2,4,338.1
|
||||
4,76,429.0
|
||||
76,36,366.1
|
||||
36,158,344.5
|
||||
158,151,350.1
|
||||
151,45,358.8
|
||||
45,93,340.9
|
||||
93,80,329.9
|
||||
80,79,384.1
|
||||
79,95,335.7
|
||||
95,98,320.9
|
||||
98,119,340.3
|
||||
119,120,376.8
|
||||
120,122,393.1
|
||||
122,141,428.7
|
||||
141,142,359.3
|
||||
30,165,379.6
|
||||
165,29,41.7
|
||||
29,38,343.3
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
17,153,375.6
|
||||
153,5,256.3
|
||||
153,62,330.9
|
||||
18,6,499.4
|
||||
6,22,254.0
|
||||
22,149,185.4
|
||||
22,4,257.9
|
||||
4,30,236.8
|
||||
30,76,307.0
|
||||
95,98,320.9
|
||||
98,144,45.1
|
||||
45,93,340.9
|
||||
93,106,112.2
|
||||
162,151,113.6
|
||||
151,108,192.9
|
||||
108,45,359.8
|
||||
146,92,311.2
|
||||
92,85,343.9
|
||||
85,37,373.2
|
||||
13,169,326.2
|
||||
169,40,96.1
|
||||
124,13,460.7
|
||||
13,7,305.5
|
||||
7,40,624.1
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
90,132,152.2
|
||||
26,32,106.7
|
||||
9,129,148.3
|
||||
129,153,219.6
|
||||
31,26,116.0
|
||||
26,155,270.7
|
||||
9,128,142.2
|
||||
128,153,215.0
|
||||
153,167,269.7
|
||||
167,62,64.8
|
||||
62,70,332.6
|
||||
124,169,145.2
|
||||
169,7,631.5
|
||||
44,169,397.8
|
||||
169,41,124.0
|
||||
44,124,375.7
|
||||
124,41,243.9
|
||||
41,7,519.4
|
||||
6,14,289.3
|
||||
14,149,259.0
|
||||
149,39,206.9
|
||||
144,98,45.1
|
||||
19,4,326.8
|
||||
4,14,178.6
|
||||
14,76,299.0
|
||||
15,151,136.4
|
||||
151,44,203.1
|
||||
45,106,260.6
|
||||
106,93,112.2
|
||||
20,165,132.5
|
||||
165,19,289.2
|
||||
89,92,323.2
|
||||
92,147,321.9
|
||||
147,37,48.2
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
104,147,277.5
|
||||
147,148,84.7
|
||||
20,101,201.2
|
||||
101,19,534.4
|
||||
19,137,245.5
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
44,42,342.3
|
||||
42,41,102.5
|
||||
44,8,789.1
|
||||
8,41,657.4
|
||||
41,117,160.5
|
||||
168,167,172.4
|
||||
167,154,165.2
|
||||
143,128,81.9
|
||||
128,10,88.2
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
15,152,135.0
|
||||
152,44,204.6
|
||||
19,77,320.7
|
||||
77,14,299.8
|
||||
14,52,127.6
|
||||
14,127,314.8
|
||||
127,39,280.4
|
||||
39,109,237.0
|
||||
31,160,116.5
|
||||
160,155,272.4
|
||||
133,91,152.8
|
||||
91,163,313.6
|
||||
150,71,221.1
|
||||
71,61,89.6
|
||||
32,160,107.7
|
||||
72,162,3274.4
|
||||
162,13,554.5
|
||||
162,40,413.9
|
||||
65,72,297.9
|
||||
72,48,21.5
|
||||
13,42,319.8
|
||||
42,40,40.7
|
||||
8,42,759.5
|
||||
42,117,58.9
|
||||
8,13,450.3
|
||||
13,117,378.5
|
||||
117,40,64.0
|
||||
46,162,391.6
|
||||
162,152,115.3
|
||||
152,108,191.4
|
||||
104,108,375.9
|
||||
108,148,311.6
|
||||
148,146,80.0
|
||||
21,90,396.9
|
||||
90,132,152.2
|
||||
101,29,252.3
|
||||
29,137,110.7
|
||||
77,22,353.8
|
||||
22,52,227.8
|
||||
52,30,186.6
|
||||
127,18,425.2
|
||||
18,109,439.1
|
||||
109,22,135.5
|
||||
168,17,232.7
|
||||
17,154,294.2
|
||||
154,5,166.3
|
||||
78,107,143.9
|
||||
107,46,236.3
|
||||
118,145,250.6
|
||||
145,96,85.1
|
||||
|
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
|
@ -2,6 +2,7 @@ import numpy as np
|
|||
import os
|
||||
|
||||
def load_dataset(config):
|
||||
|
||||
dataset_name = config['basic']['dataset']
|
||||
node_num = config['data']['num_nodes']
|
||||
input_dim = config['data']['input_dim']
|
||||
|
|
@ -10,4 +11,8 @@ def load_dataset(config):
|
|||
case 'EcoSolar':
|
||||
data_path = os.path.join('./data/EcoSolar.npy')
|
||||
data = np.load(data_path)[:, :node_num, :input_dim]
|
||||
case 'PEMS08':
|
||||
data_path = os.path.join('./data/PEMS08/PEMS08.npz')
|
||||
data = np.load(data_path)['data'][:, :node_num, :input_dim]
|
||||
|
||||
return data
|
||||
|
|
@ -0,0 +1,9 @@
|
|||
import numpy as np
|
||||
|
||||
def load_graph(config):
|
||||
dataset_path = config['data']['graph_pkl_filename']
|
||||
graph = np.load(dataset_path)
|
||||
# 将inf值填充为0
|
||||
graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
|
||||
|
||||
return graph
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 100
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 100
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
basic:
|
||||
dataset: PEMS08
|
||||
device: cuda:0
|
||||
mode: train
|
||||
model: STDEN
|
||||
seed: 2025
|
||||
data:
|
||||
add_day_in_week: true
|
||||
add_time_in_day: true
|
||||
batch_size: 32
|
||||
column_wise: false
|
||||
dataset_dir: data/PEMS08
|
||||
days_per_week: 7
|
||||
default_graph: true
|
||||
graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
|
||||
horizon: 24
|
||||
input_dim: 1
|
||||
lag: 24
|
||||
normalizer: std
|
||||
num_nodes: 170
|
||||
steps_per_day: 24
|
||||
test_ratio: 0.2
|
||||
tod: false
|
||||
val_batch_size: 32
|
||||
val_ratio: 0.2
|
||||
model:
|
||||
filter_type: default
|
||||
gcn_step: 2
|
||||
horizon: 12
|
||||
input_dim: 1
|
||||
l1_decay: 0
|
||||
latent_dim: 4
|
||||
n_traj_samples: 3
|
||||
nfe: false
|
||||
num_rnn_layers: 1
|
||||
ode_method: dopri5
|
||||
odeint_atol: 1.0e-05
|
||||
odeint_rtol: 1.0e-05
|
||||
output_dim: 1
|
||||
recg_type: gru
|
||||
rnn_units: 64
|
||||
save_latent: false
|
||||
seq_len: 12
|
||||
train:
|
||||
batch_size: 64
|
||||
debug: false
|
||||
early_stop: true
|
||||
early_stop_patience: 15
|
||||
epochs: 100
|
||||
grad_norm: false
|
||||
log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
|
||||
loss: mae
|
||||
lr_decay: false
|
||||
lr_decay_rate: 0.3
|
||||
lr_decay_step: 5,20,40,70
|
||||
lr_init: 0.003
|
||||
mae_thresh: None
|
||||
mape_thresh: 0.001
|
||||
max_grad_norm: 5
|
||||
output_dim: 1
|
||||
real_value: true
|
||||
weight_decay: 0
|
||||
2
main.py
2
main.py
|
|
@ -3,8 +3,6 @@
|
|||
时空数据深度学习预测项目主程序
|
||||
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
|
||||
"""
|
||||
|
||||
import os
|
||||
from utils.args_reader import config_loader
|
||||
import utils.init as init
|
||||
import torch
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
|
||||
from models.STDEN.stden_model import STDENModel
|
||||
|
||||
def model_selector(config):
|
||||
model_name = config['basic']['model']
|
||||
model = None
|
||||
match model_name:
|
||||
case 'STDEN': model = STDENModel(config)
|
||||
return model
|
||||
|
|
@ -2,6 +2,8 @@ import math
|
|||
import os
|
||||
import time
|
||||
import copy
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
|
||||
import torch
|
||||
|
|
@ -26,6 +28,10 @@ class Trainer:
|
|||
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
|
||||
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
|
||||
|
||||
# 用于收集nfe数据
|
||||
self.c = []
|
||||
self.res, self.keys = [], []
|
||||
|
||||
def _run_epoch(self, epoch, dataloader, mode):
|
||||
if mode == 'train':
|
||||
self.model.train()
|
||||
|
|
@ -37,17 +43,28 @@ class Trainer:
|
|||
total_loss = 0
|
||||
epoch_time = time.time()
|
||||
|
||||
# 清空nfe数据收集
|
||||
if mode == 'train':
|
||||
self.c.clear()
|
||||
|
||||
with torch.set_grad_enabled(optimizer_step):
|
||||
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
|
||||
for batch_idx, (data, target) in enumerate(dataloader):
|
||||
label = target[..., :self.args['output_dim']]
|
||||
output = self.model(data).to(self.args['device'])
|
||||
output, fe = self.model(data)
|
||||
|
||||
if self.args['real_value']:
|
||||
# 只对输出维度进行反归一化
|
||||
output = self._inverse_transform_output(output)
|
||||
|
||||
loss = self.loss(output, label)
|
||||
|
||||
# 收集nfe数据(仅在训练模式下)
|
||||
if mode == 'train':
|
||||
self.c.append([*fe, loss.item()])
|
||||
# 记录FE信息
|
||||
self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
|
||||
|
||||
if optimizer_step and self.optimizer is not None:
|
||||
self.optimizer.zero_grad()
|
||||
loss.backward()
|
||||
|
|
@ -69,6 +86,12 @@ class Trainer:
|
|||
avg_loss = total_loss / len(dataloader)
|
||||
self.logger.logger.info(
|
||||
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
|
||||
|
||||
# 收集nfe数据(仅在训练模式下)
|
||||
if mode == 'train':
|
||||
self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
|
||||
self.keys.append(epoch)
|
||||
|
||||
return avg_loss
|
||||
|
||||
def _inverse_transform_output(self, output):
|
||||
|
|
@ -143,6 +166,10 @@ class Trainer:
|
|||
best_test_model = copy.deepcopy(self.model.state_dict())
|
||||
torch.save(best_test_model, self.best_test_path)
|
||||
|
||||
# 保存nfe数据(如果启用)
|
||||
if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
|
||||
self._save_nfe_data()
|
||||
|
||||
if not self.args['debug']:
|
||||
torch.save(best_model, self.best_path)
|
||||
torch.save(best_test_model, self.best_test_path)
|
||||
|
|
@ -164,18 +191,24 @@ class Trainer:
|
|||
if path:
|
||||
checkpoint = torch.load(path)
|
||||
model.load_state_dict(checkpoint['state_dict'])
|
||||
model.to(args.device)
|
||||
model.to(args['device'])
|
||||
|
||||
model.eval()
|
||||
y_pred, y_true = [], []
|
||||
|
||||
# 用于收集nfe数据
|
||||
c = []
|
||||
|
||||
with torch.no_grad():
|
||||
for data, target in data_loader:
|
||||
label = target[..., :args['output_dim']]
|
||||
output = model(data)
|
||||
output, fe = model(data)
|
||||
y_pred.append(output)
|
||||
y_true.append(label)
|
||||
|
||||
# 收集nfe数据
|
||||
c.append([*fe, 0.0]) # 测试时没有loss,设为0
|
||||
|
||||
if args['real_value']:
|
||||
# 只对输出维度进行反归一化
|
||||
y_pred = Trainer._inverse_transform_output_static(torch.cat(y_pred, dim=0), args, scalers)
|
||||
|
|
@ -192,6 +225,10 @@ class Trainer:
|
|||
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
|
||||
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
|
||||
|
||||
# 保存nfe数据(如果启用)
|
||||
if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
|
||||
Trainer._save_nfe_data_static(c, model, logger)
|
||||
|
||||
# 只在需要时生成可视化图片
|
||||
if generate_viz:
|
||||
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
|
||||
|
|
@ -514,6 +551,26 @@ class Trainer:
|
|||
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
|
||||
plt.close()
|
||||
|
||||
def _save_nfe_data(self):
|
||||
"""保存nfe数据到文件"""
|
||||
if not self.res:
|
||||
return
|
||||
|
||||
res = pd.concat(self.res, keys=self.keys)
|
||||
res.index.names = ['epoch', 'iter']
|
||||
|
||||
# 获取模型配置参数
|
||||
filter_type = getattr(self.model, 'filter_type', 'unknown')
|
||||
atol = getattr(self.model, 'atol', 1e-5)
|
||||
rtol = getattr(self.model, 'rtol', 1e-5)
|
||||
|
||||
# 保存nfe数据
|
||||
nfe_file = os.path.join(
|
||||
self.logger.dir_path,
|
||||
'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
|
||||
res.to_pickle(nfe_file)
|
||||
self.logger.logger.info(f"NFE data saved to {nfe_file}")
|
||||
|
||||
@staticmethod
|
||||
def _compute_sampling_threshold(global_step, k):
|
||||
return k / (k + math.exp(global_step / k))
|
||||
|
|
|
|||
|
|
@ -1,11 +1,14 @@
|
|||
from trainer.trainer import Trainer
|
||||
from trainer.ode_trainer import Trainer as ode_trainer
|
||||
|
||||
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
|
||||
lr_scheduler, kwargs):
|
||||
model_name = config['basic']['model']
|
||||
selected_Trainer = None
|
||||
match model_name:
|
||||
case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
|
||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
|
||||
train_loader, val_loader, test_loader, scaler,lr_scheduler)
|
||||
train_loader, val_loader, test_loader, scaler, lr_scheduler)
|
||||
if selected_Trainer is None: raise NotImplementedError
|
||||
return selected_Trainer
|
||||
Loading…
Reference in New Issue