From 7055a6da649f9a460704e207781f9e093fb24d44 Mon Sep 17 00:00:00 2001 From: czzhangheng Date: Thu, 20 Nov 2025 22:15:48 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E6=AD=A3=E7=A1=AE=E7=9A=84Ai?= =?UTF-8?q?rQuality=E6=95=B0=E6=8D=AE=E9=9B=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 8 +++ config/REPST/AirQuality.yaml | 61 +++++++++++++++++++ ...aml => BeijingAirQuality(Deprecated).yaml} | 0 dataloader/data_selector.py | 5 ++ model/REPST/repst.py | 13 ++-- utils/Download_data.py | 3 +- utils/dataset.json | 3 +- 7 files changed, 85 insertions(+), 8 deletions(-) create mode 100755 config/REPST/AirQuality.yaml rename config/REPST/{BeijingAirQuality.yaml => BeijingAirQuality(Deprecated).yaml} (100%) diff --git a/.vscode/launch.json b/.vscode/launch.json index 1947dab..96f2427 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -60,6 +60,14 @@ "console": "integratedTerminal", "args": "--config ./config/REPST/BeijingAirQuality.yaml" }, + { + "name": "AirQuality", + "type": "debugpy", + "request": "launch", + "program": "run.py", + "console": "integratedTerminal", + "args": "--config ./config/REPST/AirQuality.yaml" + }, { "name": "AEPSA-PEMSBAY", "type": "debugpy", diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml new file mode 100755 index 0000000..b0683ae --- /dev/null +++ b/config/REPST/AirQuality.yaml @@ -0,0 +1,61 @@ +basic: + dataset: "AirQuality" + mode : "train" + device : "cuda:1" + model: "REPST" + seed: 2023 + +data: + add_day_in_week: false + add_time_in_day: false + column_wise: false + days_per_week: 7 + default_graph: true + horizon: 24 + lag: 24 + normalizer: std + num_nodes: 35 + steps_per_day: 288 + test_ratio: 0.2 + tod: false + val_ratio: 0.2 + sample: 1 + input_dim: 6 + batch_size: 16 + +model: + pred_len: 24 + seq_len: 24 + patch_len: 6 + stride: 7 + dropout: 0.2 + gpt_layers: 9 + d_ff: 128 + gpt_path: ./GPT-2 + d_model: 64 + n_heads: 1 + input_dim: 6 + output_dim: 3 + word_num: 1000 + +train: + batch_size: 16 + early_stop: true + early_stop_patience: 15 + epochs: 100 + grad_norm: false + loss_func: mae + lr_decay: true + lr_decay_rate: 0.3 + lr_decay_step: "5,20,40,70" + lr_init: 0.003 + max_grad_norm: 5 + real_value: true + weight_decay: 0 + debug: false + output_dim: 3 + log_step: 1000 + plot: false + mae_thresh: None + mape_thresh: 0.001 + diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality(Deprecated).yaml similarity index 100% rename from config/REPST/BeijingAirQuality.yaml rename to config/REPST/BeijingAirQuality(Deprecated).yaml diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py index 479fac0..45987d6 100644 --- a/dataloader/data_selector.py +++ b/dataloader/data_selector.py @@ -12,6 +12,11 @@ def load_st_dataset(config): data = np.memmap(data_path, dtype=np.float32, mode='r') L, N, C = 36000, 7, 3 data = data.reshape(L, N, C) + case "AirQuality": + data_path = os.path.join("./data/AirQuality/data.dat") + data = np.memmap(data_path, dtype=np.float32, mode='r') + L, N, C = 8701,35,6 + data = data.reshape(L, N, C) case "PEMS-BAY": data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") with h5py.File(data_path, 'r') as f: diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 3a3ce2d..6ceeb2a 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -47,7 +47,7 @@ class repst(nn.Module): self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), nn.ReLU(), - nn.Linear(128, self.pred_len) + nn.Linear(128, self.pred_len * self.output_dim) ) for i, (name, param) in enumerate(self.gpts.named_parameters()): @@ -63,7 +63,7 @@ class repst(nn.Module): torch.nn.init.zeros_(module.bias) def forward(self, x): - x = x[..., :self.output_dim] + x = x[..., :self.input_dim] x_enc = rearrange(x, 'b t n c -> b n c t') enc_out, n_vars = self.patch_embedding(x_enc) self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -73,10 +73,11 @@ class repst(nn.Module): enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings) enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state - dec_out = self.out_mlp(enc_out) - outputs = dec_out.unsqueeze(dim=-1) - outputs = outputs.repeat(1, 1, 1, n_vars) - outputs = outputs.permute(0,2,1,3) + dec_out = self.out_mlp(enc_out) #[B, N, T*C] + + B, N, _ = dec_out.shape + outputs = dec_out.view(B, N, self.pred_len, self.output_dim) + outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C return outputs diff --git a/utils/Download_data.py b/utils/Download_data.py index 75d17fe..beec5de 100755 --- a/utils/Download_data.py +++ b/utils/Download_data.py @@ -99,7 +99,8 @@ def check_and_download_data(): download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir) baq_folder = os.path.join(data_dir,"BeijingAirQuality") - if not os.path.isdir(baq_folder): + baq_folder2 = os.path.join(data_dir,"AirQuality") + if not os.path.isdir(baq_folder) or not os.path.isdir(baq_folder2): download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir) _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False) diff --git a/utils/dataset.json b/utils/dataset.json index a3e6689..a778eff 100644 --- a/utils/dataset.json +++ b/utils/dataset.json @@ -35,5 +35,6 @@ "SolarEnergy": [ "SolarEnergy.csv" ], - "BeijingAirQuality": ["data.dat", "desc.json"] + "BeijingAirQuality": ["data.dat", "desc.json"], + "AirQuality": ["data.dat"] }