From 76acc89499403920949aa4fe8b86875aca5ce7be Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 10:48:05 +0800 Subject: [PATCH 01/10] =?UTF-8?q?=E5=85=BC=E5=AE=B9METR-LA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 +++++ config/REPST/METR-LA.yaml | 60 +++++++++++++++++++++++++++++++++++++ dataloader/data_selector.py | 8 +++-- run.py | 2 ++ utils/Download_data.py | 15 ++++++++-- 5 files changed, 87 insertions(+), 6 deletions(-) create mode 100755 config/REPST/METR-LA.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index a8f3c00..281f8ac 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -36,6 +36,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/PEMS-BAY.yaml" }, + { + "name": "REPST-METR", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/METR-LA.yaml" + }, { "name": "AEPSA-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml new file mode 100755 index 0000000..2e57a1c --- /dev/null +++ b/config/REPST/METR-LA.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "METR-LA" + mode : "train" + device : "cuda:1" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: true + add_time_in_day: true + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 207 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 19fe7f5..78c3e3f 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -11,11 +11,13 @@ def load_st_dataset(config): data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") with h5py.File(data_path, 'r') as f: data = f['speed']['block0_values'][:] + case "METR-LA": + data_path = os.path.join("./data/METR-LA/METR-LA.h5") + with h5py.File(data_path, 'r') as f: + data = f['df']['block0_values'][:] case "PEMSD3": data_path = os.path.join("./data/PEMS03/PEMS03.npz") - data = np.load(data_path)["data"][ - :, :, 0 - ] + data = np.load(data_path)["data"][:, :, 0] case "PEMSD4": data_path = os.path.join("./data/PEMS04/PEMS04.npz") data = np.load(data_path)["data"][ diff --git a/run.py b/run.py index 175367f..4f48ce5 100755 --- a/run.py +++ b/run.py @@ -14,6 +14,8 @@ def main(): args = parse_args() args = init.init_device(args) init.init_seed(args["basic"]["seed"]) + + model = init.init_model(args) # Load dataset diff --git a/utils/Download_data.py b/utils/Download_data.py index fcd21f1..544d744 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -47,6 +47,9 @@ def check_and_download_data(): "adj_mx_bay.pkl", "pems-bay-meta.h5", "pems-bay.h5" + ], + "METR-LA": [ + "METR-LA.h5" ] } @@ -92,6 +95,7 @@ def check_and_download_data(): if missing_main_files: 
download_kaggle_data(current_dir, 'elmahy/pems-dataset') download_kaggle_data(current_dir, 'scchuy/pemsbay') + download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset") rearrange_dir() @@ -183,7 +187,6 @@ def rearrange_dir(): shutil.copy2(source_path, destination_path) shutil.rmtree(nested_data_dir) - # print(f"已合并 {nested_data_dir} 到 {data_dir},并删除嵌套目录。") # 将带有 "bay" 的文件移动到 PEMS-BAY 文件夹 pems_bay_dir = os.path.join(data_dir, "PEMS-BAY") @@ -195,10 +198,16 @@ def rearrange_dir(): destination_path = os.path.join(pems_bay_dir, item) shutil.move(source_path, destination_path) - # print(f"已将带有 'bay' 的文件移动到 {pems_bay_dir}。") + # metr-la + metrla_dir = os.path.join(data_dir, "METR-LA") + os.makedirs(metrla_dir, exist_ok=True) + for item in os.listdir(data_dir): + if "metr" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")): + source_path = os.path.join(data_dir, item) + destination_path = os.path.join(metrla_dir, item) + shutil.move(source_path, destination_path) # 主程序 if __name__ == "__main__": check_and_download_data() - # rearrange_dir() From a9e7cd5d3b6afd1f51464cb92d75c71ab2d0f62e Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 10:51:32 +0800 Subject: [PATCH 02/10] update data_selector --- dataloader/data_selector.py | 30 ++++++------------------------ 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 78c3e3f..3bcf3e5 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -20,40 +20,22 @@ def load_st_dataset(config): data = np.load(data_path)["data"][:, :, 0] case "PEMSD4": data_path = os.path.join("./data/PEMS04/PEMS04.npz") - data = np.load(data_path)["data"][ - :, :, 0 - ] + data = np.load(data_path)["data"][:, :, 0] case "PEMSD7": data_path = os.path.join("./data/PEMS07/PEMS07.npz") - data = np.load(data_path)["data"][ - :, :, 0 - ] + data = np.load(data_path)["data"][:, :, 0] case "PEMSD8": data_path = os.path.join("./data/PEMS08/PEMS08.npz") - data = np.load(data_path)["data"][ - :, :, 0 - ] + data = np.load(data_path)["data"][:, :, 0] case "PEMSD7(L)": data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz") - data = np.load(data_path)["data"][ - :, :, 0 - ] + data = np.load(data_path)["data"][:, :, 0] case "PEMSD7(M)": data_path = os.path.join("./data/PEMS07(M)/V_228.csv") - data = np.genfromtxt( - data_path, delimiter="," - ) - case "METR-LA": - data_path = os.path.join("./data/METR-LA/METR.h5") - with h5py.File( - data_path, "r" - ) as f: - data = np.array(f["data"]) + data = np.genfromtxt(data_path, delimiter=",") case "BJ": data_path = os.path.join("./data/BJ/BJ500.csv") - data = np.genfromtxt( - data_path, delimiter=",", skip_header=1 - ) + data = np.genfromtxt(data_path, delimiter=",", skip_header=1) case "Hainan": data_path = os.path.join("./data/Hainan/Hainan.npz") data = np.load(data_path)["data"][:, :, 0] From 9911caa3d8d19b87eee93c4af1867ee49bfb62cc Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 11:28:58 +0800 Subject: [PATCH 03/10] =?UTF-8?q?=E5=85=BC=E5=AE=B9solarEnergy?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 +++++ config/REPST/SolarEnergy.yaml | 60 +++++++++++++++++++++++++++++++++++ dataloader/data_selector.py | 3 ++ utils/Download_data.py | 15 +++++++++ 4 files changed, 86 insertions(+) create mode 100755 config/REPST/SolarEnergy.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 281f8ac..57e9984 100644 --- 
a/.vscode/launch.json +++ b/.vscode/launch.json @@ -44,6 +44,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/METR-LA.yaml" }, + { + "name": "REPST-Solar", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/SolarEnergy.yaml" + }, { "name": "AEPSA-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml new file mode 100755 index 0000000..465e53d --- /dev/null +++ b/config/REPST/SolarEnergy.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "SolarEnergy" + mode : "train" + device : "cuda:1" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 137 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 3bcf3e5..e6b1253 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -15,6 +15,9 @@ def load_st_dataset(config): data_path = os.path.join("./data/METR-LA/METR-LA.h5") with h5py.File(data_path, 'r') as f: data = f['df']['block0_values'][:] + case "SolarEnergy": + data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv") + data = np.loadtxt(data_path, delimiter=",") case "PEMSD3": data_path = os.path.join("./data/PEMS03/PEMS03.npz") data = np.load(data_path)["data"][:, :, 0] diff --git a/utils/Download_data.py b/utils/Download_data.py index 544d744..603ea01 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -50,6 +50,9 @@ def check_and_download_data(): ], "METR-LA": [ "METR-LA.h5" + ], + "SolarEnergy": [ + ] } @@ -96,6 +99,7 @@ def check_and_download_data(): download_kaggle_data(current_dir, 'elmahy/pems-dataset') download_kaggle_data(current_dir, 'scchuy/pemsbay') download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset") + download_kaggle_data(current_dir, "wangshaoqi/solar-energy") rearrange_dir() @@ -207,6 +211,17 @@ def rearrange_dir(): destination_path = os.path.join(metrla_dir, item) shutil.move(source_path, destination_path) + # solar-energy + solar_src = os.path.join(data_dir, "solar-energy") + solar_sub = os.path.join(solar_src, "solar_AL.txt") + solar_csv = os.path.join(solar_src, "solar_AL.csv") + solar_dst_dir = os.path.join(data_dir,"SolarEnergy") + solar_dst_csv = os.path.join(solar_dst_dir, "SolarEnergy.csv") + if os.path.isdir(solar_sub): shutil.rmtree(solar_sub) + if os.path.isdir(solar_src): os.rename(solar_src, solar_dst_dir) + if os.path.isfile(solar_csv.replace(solar_src, solar_dst_dir)): + os.rename(solar_csv.replace(solar_src, solar_dst_dir), solar_dst_csv) + # 主程序 if __name__ == "__main__": From a46edc79a528c6b301e94590912e4e46599f14d5 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 20:19:17 +0800 Subject: [PATCH 04/10] 
=?UTF-8?q?=E5=85=BC=E5=AE=B9BeijingAirQuality?= =?UTF-8?q?=E3=80=82=E9=87=8D=E6=9E=84data=EF=BC=8C=E9=9C=80=E8=A6=81?= =?UTF-8?q?=E6=9B=B4=E6=96=B0pip=20requirement?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 + .vscode/settings.json | 12 +- config/REPST/BeijingAirQuality.yaml | 61 ++++++ dataloader/data_selector.py | 5 + model/REPST/reprogramming.py | 9 +- model/REPST/repst.py | 29 +-- requirements.txt | 3 +- run.py | 2 +- utils/Download_data.py | 318 +++++++++------------------- utils/dataset.json | 39 ++++ 10 files changed, 235 insertions(+), 251 deletions(-) create mode 100755 config/REPST/BeijingAirQuality.yaml create mode 100644 utils/dataset.json diff --git a/.vscode/launch.json b/.vscode/launch.json index 57e9984..1947dab 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -52,6 +52,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/SolarEnergy.yaml" }, + { + "name": "BeijingAirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/BeijingAirQuality.yaml" + }, { "name": "AEPSA-PEMSBAY", "type": "debugpy", diff --git a/.vscode/settings.json b/.vscode/settings.json index 2c67ab7..5503a78 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,5 +1,11 @@ { - "python-envs.defaultEnvManager": "ms-python.python:system", - "python-envs.defaultPackageManager": "ms-python.python:pip", - "python-envs.pythonProjects": [] + "python-envs.defaultEnvManager": "ms-python.python:conda", + "python-envs.defaultPackageManager": "ms-python.python:conda", + "python-envs.pythonProjects": [ + { + "path": "data/SolarEnergy", + "envManager": "ms-python.python:system", + "packageManager": "ms-python.python:pip" + } + ] } \ No newline at end of file diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality.yaml new file mode 100755 index 0000000..1e69e2a --- /dev/null +++ b/config/REPST/BeijingAirQuality.yaml @@ -0,0 +1,61 @@ +basic: + dataset: "BeijingAirQuality" + mode : "train" + device : "cuda:1" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 12 + lag: 12 + normalizer: std + num_nodes: 7 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 3 + batch_size: 16 + +model: + pred_len: 12 + seq_len: 12 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 3 + output_dim: 3 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 3 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index e6b1253..479fac0 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -7,6 +7,11 @@ def load_st_dataset(config): # sample = config["data"]["sample"] # output B, N, D match dataset: + case "BeijingAirQuality": + data_path = os.path.join("./data/BeijingAirQuality/data.dat") + data = np.memmap(data_path, dtype=np.float32, mode='r') + L, N, C = 36000, 7, 3 + data = data.reshape(L, N, C) case "PEMS-BAY": 
data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") with h5py.File(data_path, 'r') as f: diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py index 3806989..289c4b1 100644 --- a/model/REPST/reprogramming.py +++ b/model/REPST/reprogramming.py @@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module): return output class TokenEmbedding(nn.Module): - def __init__(self, c_in, d_model, patch_num, input_dim): + def __init__(self, c_in, d_model, patch_num, input_dim, output_dim): super(TokenEmbedding, self).__init__() padding = 1 self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode='circular', bias=False) - self.confusion_layer = nn.Linear(patch_num * input_dim, 1) + self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim) for m in self.modules(): if isinstance(m, nn.Conv1d): @@ -37,17 +37,16 @@ class TokenEmbedding(nn.Module): class PatchEmbedding(nn.Module): - def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim): + def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim): super(PatchEmbedding, self).__init__() # Patching self.patch_len = patch_len self.stride = stride self.padding_patch_layer = ReplicationPad1d((0, stride)) - self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim) + self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): - n_vars = x.shape[2] x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 53a6046..3a3ce2d 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -19,6 +19,7 @@ class repst(nn.Module): self.gpt_layers = configs['gpt_layers'] self.d_ff = configs['d_ff'] self.gpt_path = configs['gpt_path'] + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) @@ -31,7 +32,7 @@ class repst(nn.Module): self.head_nf = self.d_ff * self.patch_nums # 词嵌入 - self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim) + self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim) # GPT2初始化 self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) @@ -41,7 +42,7 @@ class repst(nn.Module): self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) self.vocab_size = self.word_embeddings.shape[0] self.mapping_layer = nn.Linear(self.vocab_size, 1) - self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm) + self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm) self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), @@ -62,7 +63,7 @@ class repst(nn.Module): torch.nn.init.zeros_(module.bias) def forward(self, x): - x = x[..., :1] + x = x[..., :self.output_dim] x_enc = rearrange(x, 'b t n c -> b n c t') enc_out, n_vars = self.patch_embedding(x_enc) self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -79,25 +80,3 @@ class repst(nn.Module): return outputs -if __name__ == '__main__': - configs = { - 'device': 'cuda:0', - 'pred_len': 24, - 'seq_len': 24, - 'patch_len': 6, - 'stride': 7, - 'dropout': 0.2, - 'gpt_layers': 9, - 'd_ff': 128, 
- 'gpt_path': './GPT-2', - 'd_model': 64, - 'n_heads': 1, - 'input_dim': 1 - } - model = repst(configs) - x = torch.randn(16, 24, 325, 1) - y = model(x) - - print(y.shape) - - diff --git a/requirements.txt b/requirements.txt index d23694d..6199d59 100755 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ fastdtw notebook torchcde einops -transformers \ No newline at end of file +transformers +py7zr \ No newline at end of file diff --git a/run.py b/run.py index d9533a9..95867f4 100755 --- a/run.py +++ b/run.py @@ -15,7 +15,7 @@ def main(): args = init.init_device(args) init.init_seed(args["basic"]["seed"]) - + # Load model model = init.init_model(args) # Load dataset diff --git a/utils/Download_data.py b/utils/Download_data.py index 603ea01..75d17fe 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -1,228 +1,114 @@ -import os -import requests -import zipfile -import shutil -import kagglehub # 假设 kagglehub 是一个可用的库 +import os, json, shutil, requests +from urllib.parse import urlsplit from tqdm import tqdm +import kagglehub +import py7zr -# 定义文件完整性信息的字典 +# ---------- 1. 加载结构 JSON ---------- +def load_structure_json(path="utils/dataset.json"): + with open(path, "r", encoding="utf-8") as f: + return json.load(f) - -def check_and_download_data(): - """ - 检查 data 文件夹的完整性,并根据缺失文件类型下载相应数据。 - """ - current_working_dir = os.getcwd() # 获取当前工作目录 - data_dir = os.path.join( - current_working_dir, "data" - ) # 假设 data 文件夹在当前工作目录下 - - expected_structure = { - "PEMS03": [ - "PEMS03.csv", - "PEMS03.npz", - "PEMS03.txt", - "PEMS03_dtw_distance.npy", - "PEMS03_spatial_distance.npy", - ], - "PEMS04": [ - "PEMS04.csv", - "PEMS04.npz", - "PEMS04_dtw_distance.npy", - "PEMS04_spatial_distance.npy", - ], - "PEMS07": [ - "PEMS07.csv", - "PEMS07.npz", - "PEMS07_dtw_distance.npy", - "PEMS07_spatial_distance.npy", - ], - "PEMS08": [ - "PEMS08.csv", - "PEMS08.npz", - "PEMS08_dtw_distance.npy", - "PEMS08_spatial_distance.npy", - ], - "PEMS-BAY": [ - "adj_mx_bay.pkl", - "pems-bay-meta.h5", - "pems-bay.h5" - ], - "METR-LA": [ - "METR-LA.h5" - ], - "SolarEnergy": [ - - ] - } - - current_dir = os.getcwd() # 获取当前工作目录 - missing_adj = False - missing_main_files = False - - # 检查 data 文件夹是否存在 - if not os.path.exists(data_dir) or not os.path.isdir(data_dir): - # print(f"目录 {data_dir} 不存在。") - print("正在下载所有必要的数据文件...") - missing_adj = True - missing_main_files = True - else: - # 遍历预期的文件结构 - for subfolder, expected_files in expected_structure.items(): - subfolder_path = os.path.join(data_dir, subfolder) - - # 检查子文件夹是否存在 - if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path): - # print(f"子文件夹 {subfolder} 不存在。") - missing_main_files = True +# ---------- 2. 检测完整性 ---------- +def detect_data_integrity(data_dir, expected, check_adj=False): + missing_adj, missing_main = False, False + if not os.path.isdir(data_dir): return True, True + for folder, files in expected.items(): + folder_path = os.path.join(data_dir, folder) + if not os.path.isdir(folder_path): + if check_adj: + missing_adj = True continue + missing_main = True + continue + existing = set(os.listdir(folder_path)) + for f in files: + if f not in existing: + if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")): + missing_adj = True + elif not check_adj: + missing_main = True + return missing_adj, missing_main - # 获取子文件夹中的实际文件列表 - actual_files = os.listdir(subfolder_path) +# ---------- 3. 
Download a 7z archive and extract it ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # Download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length",0))
+                with open(file_path,"wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # Extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt==max_retries: raise RuntimeError("Download or extraction failed")
+            print("Error, retrying...", e)
+
+# ---------- 4. Download Kaggle data ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)
+
+# ---------- 5. Rearrange the data directory ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir,"data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src,dst = os.path.join(nested,item), os.path.join(data_dir,item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+
+    for kw,tgt in [("bay","PEMS-BAY"),("metr","METR-LA")]:
+        dst = os.path.join(data_dir,tgt); os.makedirs(dst,exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5",".pkl")):
+                shutil.move(os.path.join(data_dir,f), os.path.join(dst,f))
+
+    solar = os.path.join(data_dir,"solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir,"SolarEnergy"); os.makedirs(dst,exist_ok=True)
+        csv = os.path.join(solar,"solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst,"SolarEnergy.csv"))
+        shutil.rmtree(solar)
+
+# ---------- 6. 
主流程 ---------- +def check_and_download_data(): + cwd = os.getcwd() + data_dir = os.path.join(cwd,"data") + expected = load_structure_json() + + missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True) if missing_adj: - download_adj_data(current_dir) - if missing_main_files: - download_kaggle_data(current_dir, 'elmahy/pems-dataset') - download_kaggle_data(current_dir, 'scchuy/pemsbay') - download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset") - download_kaggle_data(current_dir, "wangshaoqi/solar-energy") - - rearrange_dir() + download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir) + baq_folder = os.path.join(data_dir,"BeijingAirQuality") + if not os.path.isdir(baq_folder): + download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir) + + _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False) + if missing_main: + for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]: + download_kaggle_data(cwd, ds) + + rearrange_dir() return True - -def download_adj_data(current_dir, max_retries=3): - """ - 下载并解压 adj.zip 文件,并显示下载进度条。 - 如果下载失败,最多重试 max_retries 次。 - """ - url = "http://code.zhang-heng.com/static/adj.zip" - retries = 0 - - while retries <= max_retries: - try: - print(f"正在从 {url} 下载邻接矩阵文件...") - response = requests.get(url, stream=True) - - if response.status_code == 200: - total_size = int(response.headers.get("content-length", 0)) - block_size = 1024 # 1KB - t = tqdm(total=total_size, unit="B", unit_scale=True, desc="下载进度") - - zip_file_path = os.path.join(current_dir, "adj.zip") - with open(zip_file_path, "wb") as f: - for data in response.iter_content(block_size): - f.write(data) - t.update(len(data)) - t.close() - - # print("下载完成,文件已保存到:", zip_file_path) - - if os.path.exists(zip_file_path): - with zipfile.ZipFile(zip_file_path, "r") as zip_ref: - zip_ref.extractall(current_dir) - # print("数据集已解压到:", current_dir) - os.remove(zip_file_path) # 删除zip文件 - else: - print("未找到下载的zip文件,跳过解压。") - break # 下载成功,退出循环 - else: - print(f"下载失败,状态码: {response.status_code}。请检查链接是否有效。") - except Exception as e: - print(f"下载或解压数据集时出错: {e}") - print("如果链接无效,请检查URL的合法性或稍后重试。") - - retries += 1 - if retries > max_retries: - raise Exception( - f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。" - ) - - -def download_kaggle_data(current_dir, kaggle_path): - """ - 下载 KaggleHub 数据集,并将数据直接移动到当前工作目录的 data 文件夹。 - 如果目标文件夹已存在,会覆盖冲突的文件。 - """ - try: - print(f"正在下载 {kaggle_path} 数据集...") - path = kagglehub.dataset_download(kaggle_path) - # print("Path to KaggleHub dataset files:", path) - - if os.path.exists(path): - destination_path = os.path.join(current_dir, "data") - # 使用 shutil.copytree 将文件夹内容直接放在 data 文件夹下,覆盖冲突的文件 - shutil.copytree(path, destination_path, dirs_exist_ok=True) - except Exception as e: - print(f"下载或处理 KaggleHub 数据集时出错: {e}") - - - -def rearrange_dir(): - """ - 将 data/data 中的文件合并到上级目录,并删除 data/data 目录。 - """ - data_dir = os.path.join(os.getcwd(), "data") - nested_data_dir = os.path.join(data_dir, "data") - - if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir): - for item in os.listdir(nested_data_dir): - source_path = os.path.join(nested_data_dir, item) - destination_path = os.path.join(data_dir, item) - - if os.path.isdir(source_path): - shutil.copytree(source_path, destination_path, dirs_exist_ok=True) - else: - shutil.copy2(source_path, destination_path) - - shutil.rmtree(nested_data_dir) - - # 将带有 "bay" 的文件移动到 PEMS-BAY 文件夹 - pems_bay_dir = 
os.path.join(data_dir, "PEMS-BAY") - os.makedirs(pems_bay_dir, exist_ok=True) - - for item in os.listdir(data_dir): - if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")): - source_path = os.path.join(data_dir, item) - destination_path = os.path.join(pems_bay_dir, item) - shutil.move(source_path, destination_path) - - # metr-la - metrla_dir = os.path.join(data_dir, "METR-LA") - os.makedirs(metrla_dir, exist_ok=True) - for item in os.listdir(data_dir): - if "metr" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")): - source_path = os.path.join(data_dir, item) - destination_path = os.path.join(metrla_dir, item) - shutil.move(source_path, destination_path) - - # solar-energy - solar_src = os.path.join(data_dir, "solar-energy") - solar_sub = os.path.join(solar_src, "solar_AL.txt") - solar_csv = os.path.join(solar_src, "solar_AL.csv") - solar_dst_dir = os.path.join(data_dir,"SolarEnergy") - solar_dst_csv = os.path.join(solar_dst_dir, "SolarEnergy.csv") - if os.path.isdir(solar_sub): shutil.rmtree(solar_sub) - if os.path.isdir(solar_src): os.rename(solar_src, solar_dst_dir) - if os.path.isfile(solar_csv.replace(solar_src, solar_dst_dir)): - os.rename(solar_csv.replace(solar_src, solar_dst_dir), solar_dst_csv) - - -# 主程序 -if __name__ == "__main__": +if __name__=="__main__": check_and_download_data() diff --git a/utils/dataset.json b/utils/dataset.json new file mode 100644 index 0000000..a3e6689 --- /dev/null +++ b/utils/dataset.json @@ -0,0 +1,39 @@ +{ + "PEMS03": [ + "PEMS03.csv", + "PEMS03.npz", + "PEMS03.txt", + "PEMS03_dtw_distance.npy", + "PEMS03_spatial_distance.npy" + ], + "PEMS04": [ + "PEMS04.csv", + "PEMS04.npz", + "PEMS04_dtw_distance.npy", + "PEMS04_spatial_distance.npy" + ], + "PEMS07": [ + "PEMS07.csv", + "PEMS07.npz", + "PEMS07_dtw_distance.npy", + "PEMS07_spatial_distance.npy" + ], + "PEMS08": [ + "PEMS08.csv", + "PEMS08.npz", + "PEMS08_dtw_distance.npy", + "PEMS08_spatial_distance.npy" + ], + "PEMS-BAY": [ + "adj_mx_bay.pkl", + "pems-bay-meta.h5", + "pems-bay.h5" + ], + "METR-LA": [ + "METR-LA.h5" + ], + "SolarEnergy": [ + "SolarEnergy.csv" + ], + "BeijingAirQuality": ["data.dat", "desc.json"] +} From 96f2ea1239a17086a7f3993a9ffd37a8fcfa8087 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 20:50:35 +0800 Subject: [PATCH 05/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=8F=8D=E5=BD=92?= =?UTF-8?q?=E4=B8=80=E5=8C=96=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- config/REPST/BeijingAirQuality.yaml | 8 ++++---- config/REPST/METR-LA.yaml | 12 ++++++------ config/REPST/SolarEnergy.yaml | 8 ++++---- model/REPST/normalizer.py | 2 -- model/REPST/reprogramming.py | 5 ----- trainer/Trainer.py | 21 +++++++++------------ 6 files changed, 23 insertions(+), 33 deletions(-) diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality.yaml index 1e69e2a..595c971 100755 --- a/config/REPST/BeijingAirQuality.yaml +++ b/config/REPST/BeijingAirQuality.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 7 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml index 2e57a1c..68340d1 100755 --- a/config/REPST/METR-LA.yaml +++ b/config/REPST/METR-LA.yaml @@ -11,8 +11,8 @@ data: column_wise: false 
days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 207 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 @@ -41,7 +41,7 @@ train: batch_size: 16 early_stop: true early_stop_patience: 15 - epochs: 100 + epochs: 1 grad_norm: false loss_func: mae lr_decay: true @@ -52,7 +52,7 @@ train: real_value: true weight_decay: 0 debug: false - output_dim: 1 + output_dim: 100 log_step: 1000 plot: false mae_thresh: None diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml index 465e53d..282c929 100755 --- a/config/REPST/SolarEnergy.yaml +++ b/config/REPST/SolarEnergy.yaml @@ -11,8 +11,8 @@ data: column_wise: false days_per_week: 7 default_graph: true - horizon: 12 - lag: 12 + horizon: 24 + lag: 24 normalizer: std num_nodes: 137 steps_per_day: 288 @@ -24,8 +24,8 @@ data: batch_size: 16 model: - pred_len: 12 - seq_len: 12 + pred_len: 24 + seq_len: 24 patch_len: 6 stride: 7 dropout: 0.2 diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py index fb7e182..c112c7a 100644 --- a/model/REPST/normalizer.py +++ b/model/REPST/normalizer.py @@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module): return self.gumbel_softmax(logits, 1, self.k, self.hard) def gumbel_softmax(self, logits, tau=1, k=1000, hard=True): - y_soft = F.gumbel_softmax(logits, tau, hard) - if hard: # 生成硬掩码 _, indices = y_soft.topk(k, dim=0) # 选择Top-K diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py index 289c4b1..1ba7976 100644 --- a/model/REPST/reprogramming.py +++ b/model/REPST/reprogramming.py @@ -51,7 +51,6 @@ class PatchEmbedding(nn.Module): x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) x_value_embed = self.value_embedding(x) - return self.dropout(x_value_embed), n_vars class ReprogrammingLayer(nn.Module): @@ -83,13 +82,9 @@ class ReprogrammingLayer(nn.Module): def reprogramming(self, target_embedding, source_embedding, value_embedding): B, L, H, E = target_embedding.shape - scale = 1. 
/ sqrt(E) - scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding) - A = self.dropout(torch.softmax(scale * scores, dim=-1)) reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding) - return reprogramming_embedding \ No newline at end of file diff --git a/trainer/Trainer.py b/trainer/Trainer.py index c7abff8..82c90f5 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -204,8 +204,8 @@ class Trainer: # 累积损失和预测结果 total_loss += loss.item() - y_pred.append(output.detach().cpu()) - y_true.append(label.detach().cpu()) + y_pred.append(d_output.detach().cpu()) + y_true.append(d_label.detach().cpu()) # 更新进度条 progress_bar.set_postfix(loss=d_loss.item()) @@ -356,18 +356,15 @@ class Trainer: y_pred.append(output) y_true.append(label) - # 合并所有批次的预测结果 - if args["real_value"]: - y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) - else: - y_pred = torch.cat(y_pred, dim=0) - y_true = torch.cat(y_true, dim=0) + + d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) + d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0)) # 计算并记录每个时间步的指标 - for t in range(y_true.shape[1]): + for t in range(d_y_true.shape[1]): mae, rmse, mape = all_metrics( - y_pred[:, t, ...], - y_true[:, t, ...], + d_y_pred[:, t, ...], + d_y_true[:, t, ...], args["mae_thresh"], args["mape_thresh"], ) @@ -377,7 +374,7 @@ class Trainer: # 计算并记录平均指标 mae, rmse, mape = all_metrics( - y_pred, y_true, args["mae_thresh"], args["mape_thresh"] + d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] ) logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" From 4d8708714738d6987b4b1b01260d36b2acc8357c Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 21:21:39 +0800 Subject: [PATCH 06/10] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E6=97=B6=E6=B1=87?= =?UTF-8?q?=E6=80=BB=E6=A0=B7=E6=9C=AC=20=E4=BD=BF=E7=94=A8detach=E5=88=B0?= =?UTF-8?q?gpu=20=E9=81=BF=E5=85=8D=E6=98=BE=E5=AD=98=E7=88=86=E7=82=B8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- trainer/Trainer.py | 26 ++++++++------------------ 1 file changed, 8 insertions(+), 18 deletions(-) diff --git a/trainer/Trainer.py b/trainer/Trainer.py index 82c90f5..508230e 100755 --- a/trainer/Trainer.py +++ b/trainer/Trainer.py @@ -316,13 +316,9 @@ class Trainer: def _log_model_params(self): """输出模型可训练参数数量""" - try: - total_params = sum( - p.numel() for p in self.model.parameters() if p.requires_grad - ) - self.logger.info(f"Trainable params: {total_params}") - except Exception: - pass + total_params = sum( p.numel() for p in self.model.parameters() if p.requires_grad) + self.logger.info(f"Trainable params: {total_params}") + def _finalize_training(self, best_model, best_test_model): self.model.load_state_dict(best_model) @@ -353,8 +349,8 @@ class Trainer: for data, target in data_loader: label = target[..., : args["output_dim"]] output = model(data) - y_pred.append(output) - y_true.append(label) + y_pred.append(output.detach().cpu()) + y_true.append(label.detach().cpu()) d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0)) @@ -368,17 +364,11 @@ class Trainer: args["mae_thresh"], args["mape_thresh"], ) - logger.info( - f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") # 计算并记录平均指标 - mae, rmse, mape = all_metrics( - d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"] - ) - logger.info( - f"Average Horizon, MAE: {mae:.4f}, 
RMSE: {rmse:.4f}, MAPE: {mape:.4f}" - ) + mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"]) + logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}") @staticmethod def _compute_sampling_threshold(global_step, k): From 7055a6da649f9a460704e207781f9e093fb24d44 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 22:15:48 +0800 Subject: [PATCH 07/10] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=AD=A3=E7=A1=AE?= =?UTF-8?q?=E7=9A=84AirQuality=E6=95=B0=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 +++ config/REPST/AirQuality.yaml | 61 +++++++++++++++++++ ...aml => BeijingAirQuality(Deprecated).yaml} | 0 dataloader/data_selector.py | 5 ++ model/REPST/repst.py | 13 ++-- utils/Download_data.py | 3 +- utils/dataset.json | 3 +- 7 files changed, 85 insertions(+), 8 deletions(-) create mode 100755 config/REPST/AirQuality.yaml rename config/REPST/{BeijingAirQuality.yaml => BeijingAirQuality(Deprecated).yaml} (100%) diff --git a/.vscode/launch.json b/.vscode/launch.json index 1947dab..96f2427 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -60,6 +60,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/BeijingAirQuality.yaml" }, + { + "name": "AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/AirQuality.yaml" + }, { "name": "AEPSA-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml new file mode 100755 index 0000000..b0683ae --- /dev/null +++ b/config/REPST/AirQuality.yaml @@ -0,0 +1,61 @@ +basic: + dataset: "AirQuality" + mode : "train" + device : "cuda:1" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 6 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 6 + output_dim: 3 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 3 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality(Deprecated).yaml similarity index 100% rename from config/REPST/BeijingAirQuality.yaml rename to config/REPST/BeijingAirQuality(Deprecated).yaml diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 479fac0..45987d6 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -12,6 +12,11 @@ def load_st_dataset(config): data = np.memmap(data_path, dtype=np.float32, mode='r') L, N, C = 36000, 7, 3 data = data.reshape(L, N, C) + case "AirQuality": + data_path = os.path.join("./data/AirQuality/data.dat") + data = np.memmap(data_path, dtype=np.float32, mode='r') + L, N, C = 8701,35,6 + data = data.reshape(L, N, C) case "PEMS-BAY": data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") with h5py.File(data_path, 'r') as f: 
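
The "AirQuality" case added above memory-maps a flat float32 dump and reshapes it
to (L, N, C) = (8701, 35, 6); nothing in the file itself records the shape, so the
constants in the hunk must match how the file was written. A minimal standalone
sketch of the same loading pattern (path and shape constants taken from the patch;
the assert is an extra sanity check, not part of the code above):

    import numpy as np

    L, N, C = 8701, 35, 6  # time steps, stations, channels, as hard-coded above
    data = np.memmap("./data/AirQuality/data.dat", dtype=np.float32, mode="r")
    assert data.size == L * N * C, "file size must match the hard-coded shape"
    data = data.reshape(L, N, C)  # -> (time, node, channel), what the loader returns
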
diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 3a3ce2d..6ceeb2a 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -47,7 +47,7 @@ class repst(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) for i, (name, param) in enumerate(self.gpts.named_parameters()): @@ -63,7 +63,7 @@ class repst(nn.Module): torch.nn.init.zeros_(module.bias) def forward(self, x): - x = x[..., :self.output_dim] + x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') enc_out, n_vars = self.patch_embedding(x_enc) self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -73,10 +73,11 @@ class repst(nn.Module): enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, n_vars) - outputs = outputs.permute(0,2,1,3) + dec_out = self.out_mlp(enc_out) #[B, N, T*C] + + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs diff --git a/utils/Download_data.py b/utils/Download_data.py index 75d17fe..beec5de 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -99,7 +99,8 @@ def check_and_download_data(): download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir) baq_folder = os.path.join(data_dir,"BeijingAirQuality") - if not os.path.isdir(baq_folder): + baq_folder2 = os.path.join(data_dir,"AirQuality") + if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2): download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir) _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False) diff --git a/utils/dataset.json b/utils/dataset.json index a3e6689..a778eff 100644 --- a/utils/dataset.json +++ b/utils/dataset.json @@ -35,5 +35,6 @@ "SolarEnergy": [ "SolarEnergy.csv" ], - "BeijingAirQuality": ["data.dat", "desc.json"] + "BeijingAirQuality": ["data.dat", "desc.json"], + "AirQuality": ["data.dat"] } From d06cf4e0aab3608b6d99dcd36cf41024708fb8b5 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Fri, 21 Nov 2025 09:32:52 +0800 Subject: [PATCH 08/10] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E9=85=8D=E7=BD=AE?= =?UTF-8?q?=E9=94=99=E8=AF=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/settings.json | 4 ++-- config/REPST/METR-LA.yaml | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index 5503a78..2201e74 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -4,8 +4,8 @@ "python-envs.pythonProjects": [ { "path": "data/SolarEnergy", - "envManager": "ms-python.python:system", - "packageManager": "ms-python.python:pip" + "envManager": "ms-python.python:conda", + "packageManager": "ms-python.python:conda" } ] } \ No newline at end of file diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml index 68340d1..1e3e29d 100755 --- a/config/REPST/METR-LA.yaml +++ b/config/REPST/METR-LA.yaml @@ -41,7 +41,7 @@ train: batch_size: 16 early_stop: true early_stop_patience: 15 - epochs: 1 + epochs: 100 grad_norm: false loss_func: mae lr_decay: true @@ -52,7 +52,7 @@ train: real_value: true weight_decay: 0 debug: false - output_dim: 100 + output_dim: 1 log_step: 1000 plot: false 
mae_thresh: None From b7ea73bc922c7e678ab120a631bc074f42503de9 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 23 Nov 2025 17:58:44 +0800 Subject: [PATCH 09/10] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E9=9B=86=E7=9A=84=E6=A3=80=E6=B5=8B=E4=B8=8E=E4=B8=8B=E8=BD=BD?= =?UTF-8?q?=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 9 ++- utils/Download_data.py | 128 +++++++++++++++++++++++++++++++++-------- utils/dataset.json | 3 +- 3 files changed, 115 insertions(+), 25 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index 96f2427..771325c 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,7 +4,14 @@ // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ -{ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, + { "name": "STID_PEMS-BAY", "type": "debugpy", "request": "launch", diff --git a/utils/Download_data.py b/utils/Download_data.py index beec5de..e0ce6ce 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -10,25 +10,48 @@ def load_structure_json(path="utils/dataset.json"): return json.load(f) # ---------- 2. 检测完整性 ---------- -def detect_data_integrity(data_dir, expected, check_adj=False): - missing_adj, missing_main = False, False - if not os.path.isdir(data_dir): return True, True +def detect_data_integrity(data_dir, expected): + missing_list = [] + if not os.path.isdir(data_dir): + # 如果数据目录不存在,则所有数据集都缺失 + missing_list.extend(expected.keys()) + # 标记adj也缺失 + missing_list.append("adj") + return missing_list + + # 检查adj相关文件(距离矩阵文件) + has_missing_adj = False + for folder, files in expected.items(): + folder_path = os.path.join(data_dir, folder) + if os.path.isdir(folder_path): + existing = set(os.listdir(folder_path)) + for f in files: + if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing: + has_missing_adj = True + break + if has_missing_adj: + missing_list.append("adj") + + # 检查数据集主文件 for folder, files in expected.items(): folder_path = os.path.join(data_dir, folder) if not os.path.isdir(folder_path): - if check_adj: - missing_adj = True - continue - missing_main = True + missing_list.append(folder) continue + existing = set(os.listdir(folder_path)) + has_missing_file = False + for f in files: - if f not in existing: - if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")): - missing_adj = True - elif not check_adj: - missing_main = True - return missing_adj, missing_main + # 跳过距离矩阵文件,已经在上面检查过了 + if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing: + has_missing_file = True + + if has_missing_file and folder not in missing_list: + missing_list.append(folder) + + # print(f"缺失数据集:{missing_list}") + return missing_list # ---------- 3. 下载 7z 并解压 ---------- def download_and_extract(url, dst_dir, max_retries=3): @@ -57,11 +80,34 @@ def download_and_extract(url, dst_dir, max_retries=3): # ---------- 4. 下载 Kaggle 数据 ---------- def download_kaggle_data(base_dir, dataset): try: + print(f"Downloading kaggle dataset : {dataset}") path = kagglehub.dataset_download(dataset) shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True) except Exception as e: print("Kaggle 下载失败:", dataset, e) +# ---------- 5. 
下载 GitHub 数据 ---------- +def download_github_data(file_path, save_dir): + if not os.path.exists(save_dir): + os.makedirs(save_dir) + raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}" + # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}" + response = requests.head(raw_url, allow_redirects=True) + if response.status_code != 200: + print(f"Failed to get file size for {raw_url}. Status code:", response.status_code) + return + + file_size = int(response.headers.get('Content-Length', 0)) + response = requests.get(raw_url, stream=True, allow_redirects=True) + file_name = os.path.basename(file_path) + file_path_to_save = os.path.join(save_dir, file_name) + with open(file_path_to_save, 'wb') as f: + with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar: + for chunk in response.iter_content(chunk_size=1024): + if chunk: + f.write(chunk) + pbar.update(len(chunk)) + # ---------- 5. 整理目录 ---------- def rearrange_dir(): data_dir = os.path.join(os.getcwd(), "data") @@ -92,21 +138,57 @@ def rearrange_dir(): def check_and_download_data(): cwd = os.getcwd() data_dir = os.path.join(cwd,"data") - expected = load_structure_json() + file_tree = load_structure_json() - missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True) - if missing_adj: + # 执行一次检测,获取所有缺失项 + missing_list = detect_data_integrity(data_dir, file_tree) + print(f"缺失数据集:{missing_list}") + + # 检查并下载adj数据 + if "adj" in missing_list: download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir) + # 下载后从缺失列表中移除adj + missing_list.remove("adj") - baq_folder = os.path.join(data_dir,"BeijingAirQuality") - baq_folder2 = os.path.join(data_dir,"AirQuality") - if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2): + # 检查BeijingAirQuality和AirQuality + if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list: download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir) + # 下载后更新缺失列表 + missing_list = detect_data_integrity(data_dir, file_tree) + + # 检查并下载TaxiBJ数据 + if "TaxiBJ" in missing_list: + taxi_bj_floder = os.path.join(data_dir, "BeijingTaxi") + taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy'] + for file in taxibj_files: + file_path = f"Datasets/TaxiBJ/{file}" + download_github_data(file_path, taxi_bj_floder) + # 下载后更新缺失列表 + missing_list = detect_data_integrity(data_dir, file_tree) - _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False) - if missing_main: - for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]: - download_kaggle_data(cwd, ds) + # 检查并下载pems, bay, metr-la, solar-energy数据 + # 定义数据集名称到Kaggle数据集的映射 + kaggle_map = { + "PEMS03": "elmahy/pems-dataset", + "PEMS04": "elmahy/pems-dataset", + "PEMS07": "elmahy/pems-dataset", + "PEMS08": "elmahy/pems-dataset", + "PEMS-BAY": "scchuy/pemsbay", + "METR-LA": "annnnguyen/metr-la-dataset", + "SolarEnergy": "wangshaoqi/solar-energy" + } + + # 检查是否有需要从Kaggle下载的数据集 + # 先对kaggle下载地址进行去重,避免重复下载相同的数据集 + downloaded_kaggle_datasets = set() + + for dataset, kaggle_ds in kaggle_map.items(): + if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets: + download_kaggle_data(cwd, kaggle_ds) + # 将已下载的数据集添加到集合中 + downloaded_kaggle_datasets.add(kaggle_ds) + # 下载一个数据集后更新缺失列表 + missing_list = detect_data_integrity(data_dir, file_tree) 
rearrange_dir() return True diff --git a/utils/dataset.json b/utils/dataset.json index a778eff..55f46c3 100644 --- a/utils/dataset.json +++ b/utils/dataset.json @@ -36,5 +36,6 @@ "SolarEnergy.csv" ], "BeijingAirQuality": ["data.dat", "desc.json"], - "AirQuality": ["data.dat"] + "AirQuality": ["data.dat"], + "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] } From 475a4788cd0a69eee2dea42963bc8330b9a73788 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Sun, 23 Nov 2025 19:04:50 +0800 Subject: [PATCH 10/10] =?UTF-8?q?=E5=85=BC=E5=AE=B9BJTaxi?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 15 +++++---- config/REPST/BJTaxi-Inflow.yaml | 60 +++++++++++++++++++++++++++++++++ dataloader/data_selector.py | 17 ++++++++++ utils/Download_data.py | 22 +++++------- 4 files changed, 93 insertions(+), 21 deletions(-) create mode 100755 config/REPST/BJTaxi-Inflow.yaml diff --git a/.vscode/launch.json b/.vscode/launch.json index 771325c..4f992f0 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,13 +4,6 @@ // 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ - { - "name": "Python Debugger: Current File", - "type": "debugpy", - "request": "launch", - "program": "${file}", - "console": "integratedTerminal" - }, { "name": "STID_PEMS-BAY", "type": "debugpy", @@ -35,6 +28,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/PEMSD8.yaml" }, + { + "name": "REPST-BJTaxi-InFlow", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/BJTaxi-Inflow.yaml" + }, { "name": "REPST-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml new file mode 100755 index 0000000..37577a2 --- /dev/null +++ b/config/REPST/BJTaxi-Inflow.yaml @@ -0,0 +1,60 @@ +basic: + dataset: "BJTaxi-InFlow" + mode : "train" + device : "cuda:0" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 1024 + steps_per_day: 48 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 1 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 1 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 1 + log_step: 100 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 45987d6..224b6fc 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -55,6 +55,10 @@ def load_st_dataset(config): case "SD": data_path = os.path.join("./data/SD/data.npz") data = np.load(data_path)["data"][:, :, 0].astype(np.float32) + case "BJTaxi-InFlow": + data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) + case "BJTaxi-OutFlow": + data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) case _: raise ValueError(f"Unsupported dataset: {dataset}") @@ -65,3 +69,16 @@ def load_st_dataset(config): 
print("加载 %s 数据集中... " % dataset) # return data[::sample] return data + +def read_BeijingTaxi(): + files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", + "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] + all_data = [] + for file in files: + data_path = os.path.join(f"./data/BeijingTaxi/{file}") + data = np.load(data_path) + all_data.append(data) + all_data = np.concatenate(all_data, axis=0) + time_num = all_data.shape[0] + all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) + return all_data \ No newline at end of file diff --git a/utils/Download_data.py b/utils/Download_data.py index e0ce6ce..539a804 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -4,12 +4,8 @@ from tqdm import tqdm import kagglehub import py7zr -# ---------- 1. 加载结构 JSON ---------- -def load_structure_json(path="utils/dataset.json"): - with open(path, "r", encoding="utf-8") as f: - return json.load(f) -# ---------- 2. 检测完整性 ---------- +# ---------- 1. 检测完整性 ---------- def detect_data_integrity(data_dir, expected): missing_list = [] if not os.path.isdir(data_dir): @@ -53,7 +49,7 @@ def detect_data_integrity(data_dir, expected): # print(f"缺失数据集:{missing_list}") return missing_list -# ---------- 3. 下载 7z 并解压 ---------- +# ---------- 2. 下载 7z 并解压 ---------- def download_and_extract(url, dst_dir, max_retries=3): os.makedirs(dst_dir, exist_ok=True) filename = os.path.basename(urlsplit(url).path) or "download.7z" @@ -77,7 +73,7 @@ def download_and_extract(url, dst_dir, max_retries=3): if attempt==max_retries: raise RuntimeError("下载或解压失败") print("错误,重试中...", e) -# ---------- 4. 下载 Kaggle 数据 ---------- +# ---------- 3. 下载 Kaggle 数据 ---------- def download_kaggle_data(base_dir, dataset): try: print(f"Downloading kaggle dataset : {dataset}") @@ -86,7 +82,7 @@ def download_kaggle_data(base_dir, dataset): except Exception as e: print("Kaggle 下载失败:", dataset, e) -# ---------- 5. 下载 GitHub 数据 ---------- +# ---------- 4. 下载 GitHub 数据 ---------- def download_github_data(file_path, save_dir): if not os.path.exists(save_dir): os.makedirs(save_dir) @@ -136,13 +132,13 @@ def rearrange_dir(): # ---------- 6. 主流程 ---------- def check_and_download_data(): + # 加载结构文件,检测缺失数据集 cwd = os.getcwd() data_dir = os.path.join(cwd,"data") - file_tree = load_structure_json() - - # 执行一次检测,获取所有缺失项 + with open("utils/dataset.json", "r", encoding="utf-8") as f: + file_tree = json.load(f) missing_list = detect_data_integrity(data_dir, file_tree) - print(f"缺失数据集:{missing_list}") + # print(f"缺失数据集:{missing_list}") # 检查并下载adj数据 if "adj" in missing_list: @@ -167,7 +163,6 @@ def check_and_download_data(): missing_list = detect_data_integrity(data_dir, file_tree) # 检查并下载pems, bay, metr-la, solar-energy数据 - # 定义数据集名称到Kaggle数据集的映射 kaggle_map = { "PEMS03": "elmahy/pems-dataset", "PEMS04": "elmahy/pems-dataset", @@ -178,7 +173,6 @@ def check_and_download_data(): "SolarEnergy": "wangshaoqi/solar-energy" } - # 检查是否有需要从Kaggle下载的数据集 # 先对kaggle下载地址进行去重,避免重复下载相同的数据集 downloaded_kaggle_datasets = set()