diff --git a/.gitignore b/.gitignore index 5d381cc..ac0c3ec 100644 --- a/.gitignore +++ b/.gitignore @@ -160,3 +160,4 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. #.idea/ +STDEN/ \ No newline at end of file diff --git a/.idea/.gitignore b/.idea/.gitignore new file mode 100644 index 0000000..35410ca --- /dev/null +++ b/.idea/.gitignore @@ -0,0 +1,8 @@ +# 默认忽略的文件 +/shelf/ +/workspace.xml +# 基于编辑器的 HTTP 客户端请求 +/httpRequests/ +# Datasource local storage ignored files +/dataSources/ +/dataSources.local.xml diff --git a/.idea/Project-I.iml b/.idea/Project-I.iml new file mode 100644 index 0000000..91f2557 --- /dev/null +++ b/.idea/Project-I.iml @@ -0,0 +1,12 @@ + + + + + + + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/Project_Default.xml b/.idea/inspectionProfiles/Project_Default.xml new file mode 100644 index 0000000..42faaf3 --- /dev/null +++ b/.idea/inspectionProfiles/Project_Default.xml @@ -0,0 +1,27 @@ + + + + \ No newline at end of file diff --git a/.idea/inspectionProfiles/profiles_settings.xml b/.idea/inspectionProfiles/profiles_settings.xml new file mode 100644 index 0000000..105ce2d --- /dev/null +++ b/.idea/inspectionProfiles/profiles_settings.xml @@ -0,0 +1,6 @@ + + + + \ No newline at end of file diff --git a/.idea/misc.xml b/.idea/misc.xml new file mode 100644 index 0000000..2f14dc7 --- /dev/null +++ b/.idea/misc.xml @@ -0,0 +1,7 @@ + + + + + + \ No newline at end of file diff --git a/.idea/modules.xml b/.idea/modules.xml new file mode 100644 index 0000000..777a606 --- /dev/null +++ b/.idea/modules.xml @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/.idea/vcs.xml b/.idea/vcs.xml new file mode 100644 index 0000000..e3f6bd5 --- /dev/null +++ b/.idea/vcs.xml @@ -0,0 +1,7 @@ + + + + + + + \ No newline at end of file diff --git a/data/PEMS08/PEMS08.csv b/data/PEMS08/PEMS08.csv new file mode 100755 index 0000000..d671be6 --- /dev/null +++ b/data/PEMS08/PEMS08.csv @@ -0,0 +1,296 @@ +from,to,cost +9,153,310.6 +153,62,330.9 +62,111,332.9 +111,11,324.2 +11,28,336.0 +28,169,133.7 +138,135,354.7 +135,133,387.9 +133,163,337.1 +163,20,352.0 +20,19,420.8 +19,14,351.3 +14,39,340.2 +39,164,350.3 +164,167,365.2 +167,70,359.0 +70,59,388.2 +59,58,305.7 +58,67,294.4 +67,66,299.5 +66,55,313.3 +55,53,332.1 +53,150,278.9 +150,61,308.4 +61,64,311.4 +64,63,243.6 +47,65,372.8 +65,48,319.4 +48,49,309.7 +49,54,320.5 +54,56,318.3 +56,57,297.9 +57,68,293.5 +68,69,342.5 +69,60,318.0 +60,17,305.9 +17,5,321.4 +5,18,402.2 +18,22,447.4 +22,30,377.5 +30,29,417.7 +29,21,360.8 +21,132,407.6 +132,134,386.9 +134,136,350.2 +123,121,326.3 +121,140,385.2 +140,118,393.0 +118,96,296.7 +96,94,398.2 +94,86,337.1 +86,78,473.8 +78,46,353.4 +46,152,385.7 +152,157,350.0 +157,35,354.4 +35,77,356.1 +77,52,354.2 +52,3,357.8 +3,16,382.4 +16,0,55.7 +42,12,335.1 +12,139,328.8 +139,168,412.6 +168,154,337.3 +154,143,370.7 +143,10,6.3 +107,105,354.6 +105,104,386.9 +104,148,362.1 +148,97,316.3 +97,101,380.7 +101,137,361.4 +137,102,365.5 +102,24,375.5 +24,166,312.2 +129,156,256.1 +156,33,329.1 +33,32,356.5 +91,89,405.6 +89,147,347.0 +147,15,351.7 +15,44,339.5 +44,41,350.8 +41,43,322.6 +43,100,338.9 +100,83,347.9 +83,87,327.2 +87,88,321.0 +88,75,335.8 +75,51,384.8 +51,73,391.1 +73,71,289.3 +31,155,260.0 +155,34,320.4 +34,128,393.3 +145,115,399.4 +115,112,328.1 +112,8,469.4 +8,117,816.2 +117,125,397.1 +125,127,372.7 +127,109,380.5 +109,161,355.5 +161,110,367.7 +110,160,102.0 +72,159,342.9 +159,50,383.3 +50,74,354.1 +74,82,350.2 +82,81,335.4 +81,99,391.6 +99,84,354.9 +84,13,306.4 +13,40,327.4 +40,162,413.9 +162,108,301.9 +108,146,317.8 +146,85,376.6 +85,90,347.0 +26,27,341.6 +27,6,359.4 +6,149,417.8 +149,126,388.0 +126,124,384.3 +124,7,763.3 +7,114,323.1 +114,113,351.6 +113,116,411.9 +116,144,262.0 +25,103,350.2 +103,23,376.3 +23,165,396.4 +165,38,381.0 +38,92,368.0 +92,37,336.3 +37,130,357.8 +130,106,532.3 +106,131,166.5 +1,2,371.6 +2,4,338.1 +4,76,429.0 +76,36,366.1 +36,158,344.5 +158,151,350.1 +151,45,358.8 +45,93,340.9 +93,80,329.9 +80,79,384.1 +79,95,335.7 +95,98,320.9 +98,119,340.3 +119,120,376.8 +120,122,393.1 +122,141,428.7 +141,142,359.3 +30,165,379.6 +165,29,41.7 +29,38,343.3 +65,72,297.9 +72,48,21.5 +17,153,375.6 +153,5,256.3 +153,62,330.9 +18,6,499.4 +6,22,254.0 +22,149,185.4 +22,4,257.9 +4,30,236.8 +30,76,307.0 +95,98,320.9 +98,144,45.1 +45,93,340.9 +93,106,112.2 +162,151,113.6 +151,108,192.9 +108,45,359.8 +146,92,311.2 +92,85,343.9 +85,37,373.2 +13,169,326.2 +169,40,96.1 +124,13,460.7 +13,7,305.5 +7,40,624.1 +124,169,145.2 +169,7,631.5 +90,132,152.2 +26,32,106.7 +9,129,148.3 +129,153,219.6 +31,26,116.0 +26,155,270.7 +9,128,142.2 +128,153,215.0 +153,167,269.7 +167,62,64.8 +62,70,332.6 +124,169,145.2 +169,7,631.5 +44,169,397.8 +169,41,124.0 +44,124,375.7 +124,41,243.9 +41,7,519.4 +6,14,289.3 +14,149,259.0 +149,39,206.9 +144,98,45.1 +19,4,326.8 +4,14,178.6 +14,76,299.0 +15,151,136.4 +151,44,203.1 +45,106,260.6 +106,93,112.2 +20,165,132.5 +165,19,289.2 +89,92,323.2 +92,147,321.9 +147,37,48.2 +133,91,152.8 +91,163,313.6 +150,71,221.1 +71,61,89.6 +78,107,143.9 +107,46,236.3 +104,147,277.5 +147,148,84.7 +20,101,201.2 +101,19,534.4 +19,137,245.5 +8,42,759.5 +42,117,58.9 +44,42,342.3 +42,41,102.5 +44,8,789.1 +8,41,657.4 +41,117,160.5 +168,167,172.4 +167,154,165.2 +143,128,81.9 +128,10,88.2 +118,145,250.6 +145,96,85.1 +15,152,135.0 +152,44,204.6 +19,77,320.7 +77,14,299.8 +14,52,127.6 +14,127,314.8 +127,39,280.4 +39,109,237.0 +31,160,116.5 +160,155,272.4 +133,91,152.8 +91,163,313.6 +150,71,221.1 +71,61,89.6 +32,160,107.7 +72,162,3274.4 +162,13,554.5 +162,40,413.9 +65,72,297.9 +72,48,21.5 +13,42,319.8 +42,40,40.7 +8,42,759.5 +42,117,58.9 +8,13,450.3 +13,117,378.5 +117,40,64.0 +46,162,391.6 +162,152,115.3 +152,108,191.4 +104,108,375.9 +108,148,311.6 +148,146,80.0 +21,90,396.9 +90,132,152.2 +101,29,252.3 +29,137,110.7 +77,22,353.8 +22,52,227.8 +52,30,186.6 +127,18,425.2 +18,109,439.1 +109,22,135.5 +168,17,232.7 +17,154,294.2 +154,5,166.3 +78,107,143.9 +107,46,236.3 +118,145,250.6 +145,96,85.1 diff --git a/data/PEMS08/PEMS08.npz b/data/PEMS08/PEMS08.npz new file mode 100755 index 0000000..f5f3e84 Binary files /dev/null and b/data/PEMS08/PEMS08.npz differ diff --git a/data/PEMS08/PEMS08_dtw_distance.npy b/data/PEMS08/PEMS08_dtw_distance.npy new file mode 100755 index 0000000..d462058 Binary files /dev/null and b/data/PEMS08/PEMS08_dtw_distance.npy differ diff --git a/data/PEMS08/PEMS08_spatial_distance.npy b/data/PEMS08/PEMS08_spatial_distance.npy new file mode 100755 index 0000000..cde744f Binary files /dev/null and b/data/PEMS08/PEMS08_spatial_distance.npy differ diff --git a/data/data_selector.py b/data/data_selector.py index 8ff652a..efe125e 100644 --- a/data/data_selector.py +++ b/data/data_selector.py @@ -2,6 +2,7 @@ import numpy as np import os def load_dataset(config): + dataset_name = config['basic']['dataset'] node_num = config['data']['num_nodes'] input_dim = config['data']['input_dim'] @@ -10,4 +11,8 @@ def load_dataset(config): case 'EcoSolar': data_path = os.path.join('./data/EcoSolar.npy') data = np.load(data_path)[:, :node_num, :input_dim] + case 'PEMS08': + data_path = os.path.join('./data/PEMS08/PEMS08.npz') + data = np.load(data_path)['data'][:, :node_num, :input_dim] + return data \ No newline at end of file diff --git a/data/graph_loader.py b/data/graph_loader.py new file mode 100644 index 0000000..2c3b8c2 --- /dev/null +++ b/data/graph_loader.py @@ -0,0 +1,9 @@ +import numpy as np + +def load_graph(config): + dataset_path = config['data']['graph_pkl_filename'] + graph = np.load(dataset_path) + # 将inf值填充为0 + graph = np.nan_to_num(graph, nan=0.0, posinf=0.0, neginf=0.0) + + return graph diff --git a/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml b/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml new file mode 100644 index 0000000..95a755c --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-00:53:52/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 100 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:53:52 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml b/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml new file mode 100644 index 0000000..9582c49 --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-00:56:54/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 100 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-00:56:54 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml new file mode 100644 index 0000000..94c3f70 --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-01:07:23/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 170 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:07:23 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml new file mode 100644 index 0000000..3e0b106 --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-01:08:24/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 170 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:08:24 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml new file mode 100644 index 0000000..5d4d8d0 --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-01:09:13/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 170 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:13 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml b/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml new file mode 100644 index 0000000..c6ebf04 --- /dev/null +++ b/exp/PEMS08_STDEN_2025/09/03-01:09:53/config.yaml @@ -0,0 +1,62 @@ +basic: + dataset: PEMS08 + device: cuda:0 + mode: train + model: STDEN + seed: 2025 +data: + add_day_in_week: true + add_time_in_day: true + batch_size: 32 + column_wise: false + dataset_dir: data/PEMS08 + days_per_week: 7 + default_graph: true + graph_pkl_filename: data/PEMS08/PEMS08_spatial_distance.npy + horizon: 24 + input_dim: 1 + lag: 24 + normalizer: std + num_nodes: 170 + steps_per_day: 24 + test_ratio: 0.2 + tod: false + val_batch_size: 32 + val_ratio: 0.2 +model: + filter_type: default + gcn_step: 2 + horizon: 12 + input_dim: 1 + l1_decay: 0 + latent_dim: 4 + n_traj_samples: 3 + nfe: false + num_rnn_layers: 1 + ode_method: dopri5 + odeint_atol: 1.0e-05 + odeint_rtol: 1.0e-05 + output_dim: 1 + recg_type: gru + rnn_units: 64 + save_latent: false + seq_len: 12 +train: + batch_size: 64 + debug: false + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + log_dir: /home/czzhangheng/code/Project-I/exp/PEMS08_STDEN_2025/09/03-01:09:53 + loss: mae + lr_decay: false + lr_decay_rate: 0.3 + lr_decay_step: 5,20,40,70 + lr_init: 0.003 + mae_thresh: None + mape_thresh: 0.001 + max_grad_norm: 5 + output_dim: 1 + real_value: true + weight_decay: 0 diff --git a/main.py b/main.py index ea2e00d..b2f0ef5 100644 --- a/main.py +++ b/main.py @@ -3,8 +3,6 @@ 时空数据深度学习预测项目主程序 专门处理时空数据格式 (batch_size, seq_len, num_nodes, features) """ - -import os from utils.args_reader import config_loader import utils.init as init import torch diff --git a/models/model_selector.py b/models/model_selector.py index fced561..ab80d63 100644 --- a/models/model_selector.py +++ b/models/model_selector.py @@ -1,6 +1,8 @@ - +from models.STDEN.stden_model import STDENModel def model_selector(config): model_name = config['basic']['model'] model = None + match model_name: + case 'STDEN': model = STDENModel(config) return model \ No newline at end of file diff --git a/trainer/ode_trainer.py b/trainer/ode_trainer.py index 898cd84..a448590 100644 --- a/trainer/ode_trainer.py +++ b/trainer/ode_trainer.py @@ -2,6 +2,8 @@ import math import os import time import copy +import pandas as pd +import numpy as np from tqdm import tqdm import torch @@ -25,6 +27,10 @@ class Trainer: self.best_path = os.path.join(logger.dir_path, 'best_model.pth') self.best_test_path = os.path.join(logger.dir_path, 'best_test_model.pth') self.loss_figure_path = os.path.join(logger.dir_path, 'loss.png') + + # 用于收集nfe数据 + self.c = [] + self.res, self.keys = [], [] def _run_epoch(self, epoch, dataloader, mode): if mode == 'train': @@ -36,18 +42,29 @@ class Trainer: total_loss = 0 epoch_time = time.time() + + # 清空nfe数据收集 + if mode == 'train': + self.c.clear() with torch.set_grad_enabled(optimizer_step): with tqdm(total=len(dataloader), desc=f'{mode.capitalize()} Epoch {epoch}') as pbar: for batch_idx, (data, target) in enumerate(dataloader): label = target[..., :self.args['output_dim']] - output = self.model(data).to(self.args['device']) + output, fe = self.model(data) if self.args['real_value']: # 只对输出维度进行反归一化 output = self._inverse_transform_output(output) loss = self.loss(output, label) + + # 收集nfe数据(仅在训练模式下) + if mode == 'train': + self.c.append([*fe, loss.item()]) + # 记录FE信息 + self.logger.logger.debug("FE: number - {}, time - {:.3f} s, err - {:.3f}".format(*fe, loss.item())) + if optimizer_step and self.optimizer is not None: self.optimizer.zero_grad() loss.backward() @@ -69,6 +86,12 @@ class Trainer: avg_loss = total_loss / len(dataloader) self.logger.logger.info( f'{mode.capitalize()} Epoch {epoch}: average Loss: {avg_loss:.6f}, time: {time.time() - epoch_time:.2f} s') + + # 收集nfe数据(仅在训练模式下) + if mode == 'train': + self.res.append(pd.DataFrame(self.c, columns=['nfe', 'time', 'err'])) + self.keys.append(epoch) + return avg_loss def _inverse_transform_output(self, output): @@ -143,6 +166,10 @@ class Trainer: best_test_model = copy.deepcopy(self.model.state_dict()) torch.save(best_test_model, self.best_test_path) + # 保存nfe数据(如果启用) + if hasattr(self.args, 'nfe') and bool(self.args.get('nfe', False)): + self._save_nfe_data() + if not self.args['debug']: torch.save(best_model, self.best_path) torch.save(best_test_model, self.best_test_path) @@ -164,17 +191,23 @@ class Trainer: if path: checkpoint = torch.load(path) model.load_state_dict(checkpoint['state_dict']) - model.to(args.device) + model.to(args['device']) model.eval() y_pred, y_true = [], [] + + # 用于收集nfe数据 + c = [] with torch.no_grad(): for data, target in data_loader: label = target[..., :args['output_dim']] - output = model(data) + output, fe = model(data) y_pred.append(output) y_true.append(label) + + # 收集nfe数据 + c.append([*fe, 0.0]) # 测试时没有loss,设为0 if args['real_value']: # 只对输出维度进行反归一化 @@ -192,6 +225,10 @@ class Trainer: mae, rmse, mape = logger.all_metrics(y_pred, y_true, args['mae_thresh'], args['mape_thresh']) logger.logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") + # 保存nfe数据(如果启用) + if hasattr(args, 'nfe') and bool(args.get('nfe', False)): + Trainer._save_nfe_data_static(c, model, logger) + # 只在需要时生成可视化图片 if generate_viz: save_dir = logger.dir_path if hasattr(logger, 'dir_path') else './logs' @@ -514,6 +551,26 @@ class Trainer: plt.savefig(summary_path, dpi=300, bbox_inches='tight') plt.close() + def _save_nfe_data(self): + """保存nfe数据到文件""" + if not self.res: + return + + res = pd.concat(self.res, keys=self.keys) + res.index.names = ['epoch', 'iter'] + + # 获取模型配置参数 + filter_type = getattr(self.model, 'filter_type', 'unknown') + atol = getattr(self.model, 'atol', 1e-5) + rtol = getattr(self.model, 'rtol', 1e-5) + + # 保存nfe数据 + nfe_file = os.path.join( + self.logger.dir_path, + 'nfe_{}_a{}_r{}.pkl'.format(filter_type, int(atol*1e5), int(rtol*1e5))) + res.to_pickle(nfe_file) + self.logger.logger.info(f"NFE data saved to {nfe_file}") + @staticmethod def _compute_sampling_threshold(global_step, k): return k / (k + math.exp(global_step / k)) diff --git a/trainer/trainer_selector.py b/trainer/trainer_selector.py index d5a3751..92b46d2 100644 --- a/trainer/trainer_selector.py +++ b/trainer/trainer_selector.py @@ -1,11 +1,14 @@ from trainer.trainer import Trainer +from trainer.ode_trainer import Trainer as ode_trainer def select_trainer(config, model, loss, optimizer, train_loader, val_loader, test_loader, scaler, lr_scheduler, kwargs): model_name = config['basic']['model'] selected_Trainer = None match model_name: + case 'STDEN': selected_Trainer = ode_trainer(config, model, loss, optimizer, + train_loader, val_loader, test_loader, scaler, lr_scheduler) case _: selected_Trainer = Trainer(config, model, loss, optimizer, - train_loader, val_loader, test_loader, scaler,lr_scheduler) + train_loader, val_loader, test_loader, scaler, lr_scheduler) if selected_Trainer is None: raise NotImplementedError return selected_Trainer \ No newline at end of file