diff --git a/.vscode/launch.json b/.vscode/launch.json index 4f992f0..f8690b7 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -28,6 +28,38 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/PEMSD8.yaml" }, + { + "name": "STID-BJTaxi-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/BJTaxi_Inflow.yaml" + }, + { + "name": "STID-BJTaxi-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/BJTaxi_Outflow.yaml" + }, + { + "name": "STID-NYCBike-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/NYCBike_Inflow.yaml" + }, + { + "name": "STID-NYCBike-OutFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/STID/NYCBike_Outflow.yaml" + }, { "name": "REPST-BJTaxi-InFlow", "type": "debugpy", @@ -36,6 +68,22 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" }, + { + "name": "REPST-NYCBike-outflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/NYCBike-outflow.yaml" + }, + { + "name": "REPST-NYCBike-inflow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/NYCBike-inflow.yaml" + }, { "name": "REPST-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml index 37577a2..5918c94 100755 --- a/config/REPST/BJTaxi-Inflow.yaml +++ b/config/REPST/BJTaxi-Inflow.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 24 - lag: 24 + horizon: 12 + lag: 12 normalizer: std num_nodes: 1024 steps_per_day: 48 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 24 - seq_len: 24 + pred_len: 12 + seq_len: 12 patch_len: 6 stride: 7 dropout: 0.2 diff --git a/config/REPST/BJTaxi-outflow.yaml b/config/REPST/BJTaxi-outflow.yaml new file mode 100755 index 0000000..f9994ee --- /dev/null +++ b/config/REPST/BJTaxi-outflow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "BJTaxi-OutFlow" + mode : "train" + device : "cuda:0" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/config/REPST/NYCBike-inflow.yaml b/config/REPST/NYCBike-inflow.yaml new file mode 100755 index 0000000..c7b0bd7 --- /dev/null +++ b/config/REPST/NYCBike-inflow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "NYCBike-InFlow" + mode : "train" + device : "cuda:0" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/config/REPST/NYCBike-outflow.yaml b/config/REPST/NYCBike-outflow.yaml new file mode 100755 index 0000000..a03e3c8 --- /dev/null +++ b/config/REPST/NYCBike-outflow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "NYCBike-OutFlow" + mode : "train" + device : "cuda:0" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 128 + steps_per_day: 48 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/config/STID/BJTaxi_Inflow.yaml b/config/STID/BJTaxi_Inflow.yaml new file mode 100755 index 0000000..a3cc762 --- /dev/null +++ b/config/STID/BJTaxi_Inflow.yaml @@ -0,0 +1,67 @@ +basic: + dataset: "BJTaxi-InFlow" + mode: "train" + device: "cuda:0" + model: "STID" + seed: 2023 + +data: + num_nodes: 1024 + lag: 24 + horizon: 24 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 48 + days_per_week: 7 + input_dim: 1 + output_dim: 1 + batch_size: 64 + +model: + input_dim: 3 + output_dim: 1 + history: 24 + horizon: 24 + num_nodes: 1024 + input_len: 24 + embed_dim: 32 + output_len: 24 + num_layer: 3 + if_node: True + node_dim: 32 + if_T_i_D: True + if_D_i_W: True + temp_dim_tid: 32 + temp_dim_diw: 32 + time_of_day_size: 288 + day_of_week_size: 7 + batch_size: 64 + + +train: + loss_func: mae + seed: 1 + batch_size: 64 + epochs: 100 + lr_init: 0.002 + weight_decay: 0.0001 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "1,50,80" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + debug: true + output_dim: 1 + mae_thresh: null + mape_thresh: 0.0 + log_step: 200 + plot: False diff --git a/config/STID/BJTaxi_Outflow.yaml b/config/STID/BJTaxi_Outflow.yaml new file mode 100755 index 0000000..9e59593 --- /dev/null +++ b/config/STID/BJTaxi_Outflow.yaml @@ -0,0 +1,67 @@ +basic: + dataset: "BJTaxi-OutFlow" + mode: "train" + device: "cuda:0" + model: "STID" + seed: 2023 + +data: + num_nodes: 1024 + lag: 24 + horizon: 24 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 48 + days_per_week: 7 + input_dim: 1 + output_dim: 1 + batch_size: 64 + +model: + input_dim: 3 + output_dim: 1 + history: 24 + horizon: 24 + num_nodes: 1024 + input_len: 24 + embed_dim: 32 + output_len: 24 + num_layer: 3 + if_node: True + node_dim: 32 + if_T_i_D: True + if_D_i_W: True + temp_dim_tid: 32 + temp_dim_diw: 32 + time_of_day_size: 288 + day_of_week_size: 7 + batch_size: 64 + + +train: + loss_func: mae + seed: 1 + batch_size: 64 + epochs: 100 + lr_init: 0.002 + weight_decay: 0.0001 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "1,50,80" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + debug: true + output_dim: 1 + mae_thresh: null + mape_thresh: 0.001 + log_step: 200 + plot: False diff --git a/config/STID/NYCBike_Inflow.yaml b/config/STID/NYCBike_Inflow.yaml new file mode 100755 index 0000000..506a3fd --- /dev/null +++ b/config/STID/NYCBike_Inflow.yaml @@ -0,0 +1,67 @@ +basic: + dataset: "NYCBike-InFlow" + mode: "train" + device: "cuda:0" + model: "STID" + seed: 2023 + +data: + num_nodes: 128 + lag: 24 + horizon: 24 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 48 + days_per_week: 7 + input_dim: 1 + output_dim: 1 + batch_size: 64 + +model: + input_dim: 3 + output_dim: 1 + history: 24 + horizon: 24 + num_nodes: 128 + input_len: 24 + embed_dim: 32 + output_len: 24 + num_layer: 3 + if_node: True + node_dim: 32 + if_T_i_D: True + if_D_i_W: True + temp_dim_tid: 32 + temp_dim_diw: 32 + time_of_day_size: 288 + day_of_week_size: 7 + batch_size: 64 + + +train: + loss_func: mae + seed: 1 + batch_size: 64 + epochs: 100 + lr_init: 0.002 + weight_decay: 0.0001 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "1,50,80" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + debug: true + output_dim: 1 + mae_thresh: null + mape_thresh: 0.001 + log_step: 200 + plot: False diff --git a/config/STID/NYCBike_Outflow.yaml b/config/STID/NYCBike_Outflow.yaml new file mode 100755 index 0000000..0630cce --- /dev/null +++ b/config/STID/NYCBike_Outflow.yaml @@ -0,0 +1,67 @@ +basic: + dataset: "NYCBike-OutFlow" + mode: "train" + device: "cuda:0" + model: "STID" + seed: 2023 + +data: + num_nodes: 128 + lag: 24 + horizon: 24 + val_ratio: 0.2 + test_ratio: 0.2 + tod: False + normalizer: std + column_wise: False + default_graph: True + add_time_in_day: True + add_day_in_week: True + steps_per_day: 48 + days_per_week: 7 + input_dim: 1 + output_dim: 1 + batch_size: 64 + +model: + input_dim: 3 + output_dim: 1 + history: 24 + horizon: 24 + num_nodes: 128 + input_len: 24 + embed_dim: 32 + output_len: 24 + num_layer: 3 + if_node: True + node_dim: 32 + if_T_i_D: True + if_D_i_W: True + temp_dim_tid: 32 + temp_dim_diw: 32 + time_of_day_size: 288 + day_of_week_size: 7 + batch_size: 64 + + +train: + loss_func: mae + seed: 1 + batch_size: 64 + epochs: 100 + lr_init: 0.002 + weight_decay: 0.0001 + lr_decay: False + lr_decay_rate: 0.3 + lr_decay_step: "1,50,80" + early_stop: True + early_stop_patience: 15 + grad_norm: False + max_grad_norm: 5 + real_value: True + debug: true + output_dim: 1 + mae_thresh: null + mape_thresh: 0.001 + log_step: 200 + plot: False diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 224b6fc..e0b23e1 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -59,6 +59,18 @@ def load_st_dataset(config): data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) case "BJTaxi-OutFlow": data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) + case "NYCBike-InFlow": + data_path = os.path.join("./data/NYCBike/NYC16x8.h5") + with h5py.File(data_path, 'r') as f: + data = f['data'][:].astype(np.float32) + data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2) + data = data[:, :, 0:1] + case "NYCBike-OutFlow": + data_path = os.path.join("./data/NYCBike/NYC16x8.h5") + with h5py.File(data_path, 'r') as f: + data = f['data'][:].astype(np.float32) + data = data.transpose(0,2,3,1).reshape(-1, 16*8, 2) + data = data[:, :, 1:2] case _: raise ValueError(f"Unsupported dataset: {dataset}") diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 6ceeb2a..5b709a4 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -67,7 +67,7 @@ class repst(nn.Module): x_enc = rearrange(x, 'b t n c -> b n c t') enc_out, n_vars = self.patch_embedding(x_enc) self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) - masks = self.word_choice(self.mapping_layer.weight.data.permute(1,0)) + masks = self.word_choice(self.mapping_layer.weight.data.permute(1, 0)) source_embeddings = self.word_embeddings[masks==1] enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) diff --git a/utils/Download_data.py b/utils/Download_data.py index 88eeebe..f31b937 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -161,6 +161,13 @@ def check_and_download_data(): download_github_data(file_path, taxi_bj_floder) # 下载后更新缺失列表 missing_list = detect_data_integrity(data_dir, file_tree) + + # 检查并下载TaxiBJ数据 + if "NYCBike" in missing_list: + nycbike_bj_floder = os.path.join(data_dir, "NYCBike") + download_and_extract("http://code.zhang-heng.com/static/NYCBike.7z", data_dir) + # 下载后更新缺失列表 + missing_list = detect_data_integrity(data_dir, file_tree) # 检查并下载pems, bay, metr-la, solar-energy数据 kaggle_map = { diff --git a/utils/dataset.json b/utils/dataset.json index 55f46c3..1228ee0 100644 --- a/utils/dataset.json +++ b/utils/dataset.json @@ -37,5 +37,6 @@ ], "BeijingAirQuality": ["data.dat", "desc.json"], "AirQuality": ["data.dat"], - "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] + "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"], + "NYCBike": ["NYC16x8.h5"] }