diff --git a/.gitignore b/.gitignore
index 5d381cc..ac0c3ec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,3 +160,4 @@ cython_debug/
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
+STDEN/
\ No newline at end of file
diff --git a/.idea/.gitignore b/.idea/.gitignore
new file mode 100644
index 0000000..35410ca
--- /dev/null
+++ b/.idea/.gitignore
@@ -0,0 +1,8 @@
+# 默认忽略的文件
+/shelf/
+/workspace.xml
+# 基于编辑器的 HTTP 客户端请求
+/httpRequests/
+# Datasource local storage ignored files
+/dataSources/
+/dataSources.local.xml
diff --git a/.idea/Project-I.iml b/.idea/Project-I.iml
new file mode 100644
index 0000000..91f2557
--- /dev/null
+++ b/.idea/Project-I.iml
@@ -0,0 +1,12 @@
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml
new file mode 100644
index 0000000..42faaf3
--- /dev/null
+++ b/.idea/inspectionProfiles/Project_Default.xml
@@ -0,0 +1,27 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml
new file mode 100644
index 0000000..105ce2d
--- /dev/null
+++ b/.idea/inspectionProfiles/profiles_settings.xml
@@ -0,0 +1,6 @@
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
new file mode 100644
index 0000000..2f14dc7
--- /dev/null
+++ b/.idea/misc.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
new file mode 100644
index 0000000..777a606
--- /dev/null
+++ b/.idea/modules.xml
@@ -0,0 +1,8 @@
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
new file mode 100644
index 0000000..e3f6bd5
--- /dev/null
+++ b/.idea/vcs.xml
@@ -0,0 +1,7 @@
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/data/PEMS08/PEMS08.csv b/data/PEMS08/PEMS08.csv
new file mode 100755
index 0000000..d671be6
--- /dev/null
+++ b/data/PEMS08/PEMS08.csv
@@ -0,0 +1,296 @@
+from,to,cost
+9,153,310.6
+153,62,330.9
+62,111,332.9
+111,11,324.2
+11,28,336.0
+28,169,133.7
+138,135,354.7
+135,133,387.9
+133,163,337.1
+163,20,352.0
+20,19,420.8
+19,14,351.3
+14,39,340.2
+39,164,350.3
+164,167,365.2
+167,70,359.0
+70,59,388.2
+59,58,305.7
+58,67,294.4
+67,66,299.5
+66,55,313.3
+55,53,332.1
+53,150,278.9
+150,61,308.4
+61,64,311.4
+64,63,243.6
+47,65,372.8
+65,48,319.4
+48,49,309.7
+49,54,320.5
+54,56,318.3
+56,57,297.9
+57,68,293.5
+68,69,342.5
+69,60,318.0
+60,17,305.9
+17,5,321.4
+5,18,402.2
+18,22,447.4
+22,30,377.5
+30,29,417.7
+29,21,360.8
+21,132,407.6
+132,134,386.9
+134,136,350.2
+123,121,326.3
+121,140,385.2
+140,118,393.0
+118,96,296.7
+96,94,398.2
+94,86,337.1
+86,78,473.8
+78,46,353.4
+46,152,385.7
+152,157,350.0
+157,35,354.4
+35,77,356.1
+77,52,354.2
+52,3,357.8
+3,16,382.4
+16,0,55.7
+42,12,335.1
+12,139,328.8
+139,168,412.6
+168,154,337.3
+154,143,370.7
+143,10,6.3
+107,105,354.6
+105,104,386.9
+104,148,362.1
+148,97,316.3
+97,101,380.7
+101,137,361.4
+137,102,365.5
+102,24,375.5
+24,166,312.2
+129,156,256.1
+156,33,329.1
+33,32,356.5
+91,89,405.6
+89,147,347.0
+147,15,351.7
+15,44,339.5
+44,41,350.8
+41,43,322.6
+43,100,338.9
+100,83,347.9
+83,87,327.2
+87,88,321.0
+88,75,335.8
+75,51,384.8
+51,73,391.1
+73,71,289.3
+31,155,260.0
+155,34,320.4
+34,128,393.3
+145,115,399.4
+115,112,328.1
+112,8,469.4
+8,117,816.2
+117,125,397.1
+125,127,372.7
+127,109,380.5
+109,161,355.5
+161,110,367.7
+110,160,102.0
+72,159,342.9
+159,50,383.3
+50,74,354.1
+74,82,350.2
+82,81,335.4
+81,99,391.6
+99,84,354.9
+84,13,306.4
+13,40,327.4
+40,162,413.9
+162,108,301.9
+108,146,317.8
+146,85,376.6
+85,90,347.0
+26,27,341.6
+27,6,359.4
+6,149,417.8
+149,126,388.0
+126,124,384.3
+124,7,763.3
+7,114,323.1
+114,113,351.6
+113,116,411.9
+116,144,262.0
+25,103,350.2
+103,23,376.3
+23,165,396.4
+165,38,381.0
+38,92,368.0
+92,37,336.3
+37,130,357.8
+130,106,532.3
+106,131,166.5
+1,2,371.6
+2,4,338.1
+4,76,429.0
+76,36,366.1
+36,158,344.5
+158,151,350.1
+151,45,358.8
+45,93,340.9
+93,80,329.9
+80,79,384.1
+79,95,335.7
+95,98,320.9
+98,119,340.3
+119,120,376.8
+120,122,393.1
+122,141,428.7
+141,142,359.3
+30,165,379.6
+165,29,41.7
+29,38,343.3
+65,72,297.9
+72,48,21.5
+17,153,375.6
+153,5,256.3
+153,62,330.9
+18,6,499.4
+6,22,254.0
+22,149,185.4
+22,4,257.9
+4,30,236.8
+30,76,307.0
+95,98,320.9
+98,144,45.1
+45,93,340.9
+93,106,112.2
+162,151,113.6
+151,108,192.9
+108,45,359.8
+146,92,311.2
+92,85,343.9
+85,37,373.2
+13,169,326.2
+169,40,96.1
+124,13,460.7
+13,7,305.5
+7,40,624.1
+124,169,145.2
+169,7,631.5
+90,132,152.2
+26,32,106.7
+9,129,148.3
+129,153,219.6
+31,26,116.0
+26,155,270.7
+9,128,142.2
+128,153,215.0
+153,167,269.7
+167,62,64.8
+62,70,332.6
+124,169,145.2
+169,7,631.5
+44,169,397.8
+169,41,124.0
+44,124,375.7
+124,41,243.9
+41,7,519.4
+6,14,289.3
+14,149,259.0
+149,39,206.9
+144,98,45.1
+19,4,326.8
+4,14,178.6
+14,76,299.0
+15,151,136.4
+151,44,203.1
+45,106,260.6
+106,93,112.2
+20,165,132.5
+165,19,289.2
+89,92,323.2
+92,147,321.9
+147,37,48.2
+133,91,152.8
+91,163,313.6
+150,71,221.1
+71,61,89.6
+78,107,143.9
+107,46,236.3
+104,147,277.5
+147,148,84.7
+20,101,201.2
+101,19,534.4
+19,137,245.5
+8,42,759.5
+42,117,58.9
+44,42,342.3
+42,41,102.5
+44,8,789.1
+8,41,657.4
+41,117,160.5
+168,167,172.4
+167,154,165.2
+143,128,81.9
+128,10,88.2
+118,145,250.6
+145,96,85.1
+15,152,135.0
+152,44,204.6
+19,77,320.7
+77,14,299.8
+14,52,127.6
+14,127,314.8
+127,39,280.4
+39,109,237.0
+31,160,116.5
+160,155,272.4
+133,91,152.8
+91,163,313.6
+150,71,221.1
+71,61,89.6
+32,160,107.7
+72,162,3274.4
+162,13,554.5
+162,40,413.9
+65,72,297.9
+72,48,21.5
+13,42,319.8
+42,40,40.7
+8,42,759.5
+42,117,58.9
+8,13,450.3
+13,117,378.5
+117,40,64.0
+46,162,391.6
+162,152,115.3
+152,108,191.4
+104,108,375.9
+108,148,311.6
+148,146,80.0
+21,90,396.9
+90,132,152.2
+101,29,252.3
+29,137,110.7
+77,22,353.8
+22,52,227.8
+52,30,186.6
+127,18,425.2
+18,109,439.1
+109,22,135.5
+168,17,232.7
+17,154,294.2
+154,5,166.3
+78,107,143.9
+107,46,236.3
+118,145,250.6
+145,96,85.1
diff --git a/data/PEMS08/PEMS08.npz b/data/PEMS08/PEMS08.npz
new file mode 100755
index 0000000..f5f3e84
Binary files /dev/null and b/data/PEMS08/PEMS08.npz differ
diff --git a/data/PEMS08/PEMS08_dtw_distance.npy b/data/PEMS08/PEMS08_dtw_distance.npy
new file mode 100755
index 0000000..d462058
Binary files /dev/null and b/data/PEMS08/PEMS08_dtw_distance.npy differ
diff --git a/data/PEMS08/PEMS08_spatial_distance.npy b/data/PEMS08/PEMS08_spatial_distance.npy
new file mode 100755
index 0000000..cde744f
Binary files /dev/null and b/data/PEMS08/PEMS08_spatial_distance.npy differ
diff --git a/data/data_selector.py b/data/data_selector.py
index 8ff652a..efe125e 100644
--- a/data/data_selector.py
+++ b/data/data_selector.py
@@ -2,6 +2,7 @@ import numpy as np
import os
def load_dataset(config):
+
dataset_name = config['basic']['dataset']
node_num = config['data']['num_nodes']
input_dim = config['data']['input_dim']
@@ -10,4 +11,8 @@ def load_dataset(config):
case 'EcoSolar':
data_path = os.path.join('./data/EcoSolar.npy')
data = np.load(data_path)[:, :node_num, :input_dim]
+ case 'PEMS08':
+ data_path = os.path.join('./data/PEMS08/PEMS08.npz')
+ data = np.load(data_path)['data'][:, :node_num, :input_dim]
+
return data
\ No newline at end of file
diff --git a/data/graph_loader.py b/data/graph_loader.py
new file mode 100644
index 0000000..2c3b8c2
--- /dev/null
+++ b/data/graph_loader.py
@@ -0,0 +1,9 @@
+import numpy as np
+
+def load_graph(config):
+ dataset_path = config['data']['graph_pkl_filename']
+ graph = np.load(dataset_path)
+ # 将inf值填充为0
+ graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0)
+
+ return graph
diff --git a/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml b/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml
new file mode 100644
index 0000000..95a755c
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 100
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml b/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml
new file mode 100644
index 0000000..9582c49
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 100
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml
new file mode 100644
index 0000000..94c3f70
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 170
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml
new file mode 100644
index 0000000..3e0b106
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 170
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml
new file mode 100644
index 0000000..5d4d8d0
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 170
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml
new file mode 100644
index 0000000..c6ebf04
--- /dev/null
+++ b/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml
@@ -0,0 +1,62 @@
+basic:
+ dataset: PEMS08
+ device: cuda:0
+ mode: train
+ model: STDEN
+ seed: 2025
+data:
+ add_day_in_week: true
+ add_time_in_day: true
+ batch_size: 32
+ column_wise: false
+ dataset_dir: data/PEMS08
+ days_per_week: 7
+ default_graph: true
+ graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy
+ horizon: 24
+ input_dim: 1
+ lag: 24
+ normalizer: std
+ num_nodes: 170
+ steps_per_day: 24
+ test_ratio: 0.2
+ tod: false
+ val_batch_size: 32
+ val_ratio: 0.2
+model:
+ filter_type: default
+ gcn_step: 2
+ horizon: 12
+ input_dim: 1
+ l1_decay: 0
+ latent_dim: 4
+ n_traj_samples: 3
+ nfe: false
+ num_rnn_layers: 1
+ ode_method: dopri5
+ odeint_atol: 1.0e-05
+ odeint_rtol: 1.0e-05
+ output_dim: 1
+ recg_type: gru
+ rnn_units: 64
+ save_latent: false
+ seq_len: 12
+train:
+ batch_size: 64
+ debug: false
+ early_stop: true
+ early_stop_patience: 15
+ epochs: 100
+ grad_norm: false
+ log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53
+ loss: mae
+ lr_decay: false
+ lr_decay_rate: 0.3
+ lr_decay_step: 5,20,40,70
+ lr_init: 0.003
+ mae_thresh: None
+ mape_thresh: 0.001
+ max_grad_norm: 5
+ output_dim: 1
+ real_value: true
+ weight_decay: 0
diff --git a/main.py b/main.py
index ea2e00d..b2f0ef5 100644
--- a/main.py
+++ b/main.py
@@ -3,8 +3,6 @@
时空数据深度学习预测项目主程序
专门处理时空数据格式 (batch_size, seq_len, num_nodes, features)
"""
-
-import os
from utils.args_reader import config_loader
import utils.init as init
import torch
diff --git a/models/model_selector.py b/models/model_selector.py
index fced561..ab80d63 100644
--- a/models/model_selector.py
+++ b/models/model_selector.py
@@ -1,6 +1,8 @@
-
+from models.STDEN.stden_model import STDENModel
def model_selector(config):
model_name = config['basic']['model']
model = None
+ match model_name:
+ case 'STDEN': model = STDENModel(config)
return model
\ No newline at end of file
diff --git a/trainer/ode_trainer.py b/trainer/ode_trainer.py
index 898cd84..a448590 100644
--- a/trainer/ode_trainer.py
+++ b/trainer/ode_trainer.py
@@ -2,6 +2,8 @@ import math
import os
import time
import copy
+import pandas as pd
+import numpy as np
from tqdm import tqdm
import torch
@@ -25,6 +27,10 @@ class Trainer:
self.best_path = os.path.join(logger.dir_path, 'best_model.pth')
self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth')
self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png')
+
+ # 用于收集nfe数据
+ self.c = []
+ self.res, self.keys = [], []
def _run_epoch(self, epoch, dataloader, mode):
if mode == 'train':
@@ -36,18 +42,29 @@ class Trainer:
total_loss = 0
epoch_time = time.time()
+
+ # 清空nfe数据收集
+ if mode == 'train':
+ self.c.clear()
with torch.set_grad_enabled(optimizer_step):
with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar:
for batch_idx, (data, target) in enumerate(dataloader):
label = target[..., :self.args['output_dim']]
- output = self.model(data).to(self.args['device'])
+ output, fe = self.model(data)
if self.args['real_value']:
# 只对输出维度进行反归一化
output = self._inverse_transform_output(output)
loss = self.loss(output, label)
+
+ # 收集nfe数据(仅在训练模式下)
+ if mode == 'train':
+ self.c.append([*fe, loss.item()])
+ # 记录FE信息
+ self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item()))
+
if optimizer_step and self.optimizer is not None:
self.optimizer.zero_grad()
loss.backward()
@@ -69,6 +86,12 @@ class Trainer:
avg_loss = total_loss / len(dataloader)
self.logger.logger.info(
f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s')
+
+ # 收集nfe数据(仅在训练模式下)
+ if mode == 'train':
+ self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err']))
+ self.keys.append(epoch)
+
return avg_loss
def _inverse_transform_output(self, output):
@@ -143,6 +166,10 @@ class Trainer:
best_test_model = copy.deepcopy(self.model.state_dict())
torch.save(best_test_model, self.best_test_path)
+ # 保存nfe数据(如果启用)
+ if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)):
+ self._save_nfe_data()
+
if not self.args['debug']:
torch.save(best_model, self.best_path)
torch.save(best_test_model, self.best_test_path)
@@ -164,17 +191,23 @@ class Trainer:
if path:
checkpoint = torch.load(path)
model.load_state_dict(checkpoint['state_dict'])
- model.to(args.device)
+ model.to(args['device'])
model.eval()
y_pred, y_true = [], []
+
+ # 用于收集nfe数据
+ c = []
with torch.no_grad():
for data, target in data_loader:
label = target[..., :args['output_dim']]
- output = model(data)
+ output, fe = model(data)
y_pred.append(output)
y_true.append(label)
+
+ # 收集nfe数据
+ c.append([*fe, 0.0]) # 测试时没有loss,设为0
if args['real_value']:
# 只对输出维度进行反归一化
@@ -192,6 +225,10 @@ class Trainer:
mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh'])
logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
+ # 保存nfe数据(如果启用)
+ if hasattr(args, 'nfe') and bool(args.get('nfe', False)):
+ Trainer._save_nfe_data_static(c, model, logger)
+
# 只在需要时生成可视化图片
if generate_viz:
save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs'
@@ -514,6 +551,26 @@ class Trainer:
plt.savefig(summary_path, dpi=300, bbox_inches='tight')
plt.close()
+ def _save_nfe_data(self):
+ """保存nfe数据到文件"""
+ if not self.res:
+ return
+
+ res = pd.concat(self.res, keys=self.keys)
+ res.index.names = ['epoch', 'iter']
+
+ # 获取模型配置参数
+ filter_type = getattr(self.model, 'filter_type', 'unknown')
+ atol = getattr(self.model, 'atol', 1e-5)
+ rtol = getattr(self.model, 'rtol', 1e-5)
+
+ # 保存nfe数据
+ nfe_file = os.path.join(
+ self.logger.dir_path,
+ 'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5)))
+ res.to_pickle(nfe_file)
+ self.logger.logger.info(f"NFE data saved to {nfe_file}")
+
@staticmethod
def _compute_sampling_threshold(global_step, k):
return k / (k + math.exp(global_step / k))
diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py
index d5a3751..92b46d2 100644
--- a/trainer/trainer_selector.py
+++ b/trainer/trainer_selector.py
@@ -1,11 +1,14 @@
from trainer.trainer import Trainer
+from trainer.ode_trainer import Trainer as ode_trainer
def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler,
lr_scheduler, kwargs):
model_name = config['basic']['model']
selected_Trainer = None
match model_name:
+ case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer,
+ train_loader, val_loader, test_loader, scaler, lr_scheduler)
case _: selected_Trainer = Trainer(config, model, loss, optimizer,
- train_loader, val_loader, test_loader, scaler,lr_scheduler)
+ train_loader, val_loader, test_loader, scaler, lr_scheduler)
if selected_Trainer is None: raise NotImplementedError
return selected_Trainer
\ No newline at end of file