diff --git a/.vscode/launch.json b/.vscode/launch.json
index 57e9984..1947dab 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -52,6 +52,14 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/SolarEnergy.yaml"
         },
+        {
+            "name": "BeijingAirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BeijingAirQuality.yaml"
+        },
         {
             "name": "AEPSA-PEMSBAY",
             "type": "debugpy",
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 2c67ab7..5503a78 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,5 +1,11 @@
 {
-    "python-envs.defaultEnvManager": "ms-python.python:system",
-    "python-envs.defaultPackageManager": "ms-python.python:pip",
-    "python-envs.pythonProjects": []
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
+    "python-envs.defaultPackageManager": "ms-python.python:conda",
+    "python-envs.pythonProjects": [
+        {
+            "path": "data/SolarEnergy",
+            "envManager": "ms-python.python:system",
+            "packageManager": "ms-python.python:pip"
+        }
+    ]
 }
\ No newline at end of file
diff --git a/config/REPST/BeijingAirQuality.yaml b/config/REPST/BeijingAirQuality.yaml
new file mode 100755
index 0000000..1e69e2a
--- /dev/null
+++ b/config/REPST/BeijingAirQuality.yaml
@@ -0,0 +1,61 @@
+basic:
+  dataset: "BeijingAirQuality"
+  mode : "train"
+  device : "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 12
+  lag: 12
+  normalizer: std
+  num_nodes: 7
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 3
+  batch_size: 16
+
+model:
+  pred_len: 12
+  seq_len: 12
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 3
+  output_dim: 3
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py
index e6b1253..479fac0 100644
--- a/dataloader/data_selector.py
+++ b/dataloader/data_selector.py
@@ -7,6 +7,11 @@ def load_st_dataset(config):
     # sample = config["data"]["sample"]
     # output B, N, D
     match dataset:
+        case "BeijingAirQuality":
+            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 36000, 7, 3
+            data = data.reshape(L, N, C)
         case "PEMS-BAY":
             data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
             with h5py.File(data_path, 'r') as f:
diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py
index 3806989..289c4b1 100644
--- a/model/REPST/reprogramming.py
+++ b/model/REPST/reprogramming.py
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output


 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+ self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim) for m in self.modules(): if isinstance(m, nn.Conv1d): @@ -37,17 +37,16 @@ class TokenEmbedding(nn.Module): class PatchEmbedding(nn.Module): - def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim): + def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim): super(PatchEmbedding, self).__init__() # Patching self.patch_len = patch_len self.stride = stride self.padding_patch_layer = ReplicationPad1d((0, stride)) - self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim) + self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): - n_vars = x.shape[2] x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) diff --git a/model/REPST/repst.py b/model/REPST/repst.py index 53a6046..3a3ce2d 100644 --- a/model/REPST/repst.py +++ b/model/REPST/repst.py @@ -19,6 +19,7 @@ class repst(nn.Module): self.gpt_layers = configs['gpt_layers'] self.d_ff = configs['d_ff'] self.gpt_path = configs['gpt_path'] + self.output_dim = configs.get('output_dim', 1) self.word_choice = GumbelSoftmax(configs['word_num']) @@ -31,7 +32,7 @@ class repst(nn.Module): self.head_nf = self.d_ff * self.patch_nums # 词嵌入 - self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim) + self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim) # GPT2初始化 self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True) @@ -41,7 +42,7 @@ class repst(nn.Module): self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device) self.vocab_size = self.word_embeddings.shape[0] self.mapping_layer = nn.Linear(self.vocab_size, 1) - self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm) + self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm) self.out_mlp = nn.Sequential( nn.Linear(self.d_llm, 128), @@ -62,7 +63,7 @@ class repst(nn.Module): torch.nn.init.zeros_(module.bias) def forward(self, x): - x = x[..., :1] + x = x[..., :self.output_dim] x_enc = rearrange(x, 'b t n c -> b n c t') enc_out, n_vars = self.patch_embedding(x_enc) self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0) @@ -79,25 +80,3 @@ class repst(nn.Module): return outputs -if __name__ == '__main__': - configs = { - 'device': 'cuda:0', - 'pred_len': 24, - 'seq_len': 24, - 'patch_len': 6, - 'stride': 7, - 'dropout': 0.2, - 'gpt_layers': 9, - 'd_ff': 128, - 'gpt_path': './GPT-2', - 'd_model': 64, - 'n_heads': 1, - 'input_dim': 1 - } - model = repst(configs) - x = torch.randn(16, 24, 325, 1) - y = model(x) - - print(y.shape) - - diff --git a/requirements.txt b/requirements.txt index d23694d..6199d59 100755 --- a/requirements.txt +++ b/requirements.txt @@ -11,4 +11,5 @@ fastdtw notebook torchcde einops -transformers \ No newline at end of file +transformers +py7zr \ No newline at end of file diff --git a/run.py b/run.py index d9533a9..95867f4 100755 --- a/run.py +++ b/run.py @@ -15,7 +15,7 @@ def main(): args = init.init_device(args) init.init_seed(args["basic"]["seed"]) - + # Load model model = init.init_model(args) # Load dataset diff --git a/utils/Download_data.py 
index 603ea01..75d17fe 100755
--- a/utils/Download_data.py
+++ b/utils/Download_data.py
@@ -1,228 +1,114 @@
-import os
-import requests
-import zipfile
-import shutil
-import kagglehub  # assume kagglehub is an available library
+import os, json, shutil, requests
+from urllib.parse import urlsplit
 from tqdm import tqdm
+import kagglehub
+import py7zr

-# dictionary describing the expected contents of each dataset folder
+# ---------- 1. Load the structure JSON ----------
+def load_structure_json(path="utils/dataset.json"):
+    with open(path, "r", encoding="utf-8") as f:
+        return json.load(f)

-
-def check_and_download_data():
-    """
-    Check the integrity of the data folder and download the corresponding data
-    for whatever type of file is missing.
-    """
-    current_working_dir = os.getcwd()  # current working directory
-    data_dir = os.path.join(
-        current_working_dir, "data"
-    )  # assume the data folder lives under the current working directory
-
-    expected_structure = {
-        "PEMS03": [
-            "PEMS03.csv",
-            "PEMS03.npz",
-            "PEMS03.txt",
-            "PEMS03_dtw_distance.npy",
-            "PEMS03_spatial_distance.npy",
-        ],
-        "PEMS04": [
-            "PEMS04.csv",
-            "PEMS04.npz",
-            "PEMS04_dtw_distance.npy",
-            "PEMS04_spatial_distance.npy",
-        ],
-        "PEMS07": [
-            "PEMS07.csv",
-            "PEMS07.npz",
-            "PEMS07_dtw_distance.npy",
-            "PEMS07_spatial_distance.npy",
-        ],
-        "PEMS08": [
-            "PEMS08.csv",
-            "PEMS08.npz",
-            "PEMS08_dtw_distance.npy",
-            "PEMS08_spatial_distance.npy",
-        ],
-        "PEMS-BAY": [
-            "adj_mx_bay.pkl",
-            "pems-bay-meta.h5",
-            "pems-bay.h5"
-        ],
-        "METR-LA": [
-            "METR-LA.h5"
-        ],
-        "SolarEnergy": [
-
-        ]
-    }
-
-    current_dir = os.getcwd()  # current working directory
-    missing_adj = False
-    missing_main_files = False
-
-    # check whether the data folder exists
-    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
-        # print(f"Directory {data_dir} does not exist.")
-        print("Downloading all required data files...")
-        missing_adj = True
-        missing_main_files = True
-    else:
-        # walk the expected file structure
-        for subfolder, expected_files in expected_structure.items():
-            subfolder_path = os.path.join(data_dir, subfolder)
-
-            # check whether the subfolder exists
-            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
-                # print(f"Subfolder {subfolder} does not exist.")
-                missing_main_files = True
+# ---------- 2. Check data integrity ----------
+def detect_data_integrity(data_dir, expected, check_adj=False):
+    missing_adj, missing_main = False, False
+    if not os.path.isdir(data_dir): return True, True
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if not os.path.isdir(folder_path):
+            if check_adj:
+                missing_adj = True
                 continue
+            missing_main = True
+            continue
+        existing = set(os.listdir(folder_path))
+        for f in files:
+            if f not in existing:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")):
+                    missing_adj = True
+                elif not check_adj:
+                    missing_main = True
+    return missing_adj, missing_main

-            # list the files actually present in the subfolder
-            actual_files = os.listdir(subfolder_path)
+# ---------- 3. Download and extract a 7z archive ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length",0))
+                with open(file_path,"wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt==max_retries: raise RuntimeError("Download or extraction failed")
+            print("Error, retrying...", e)

-            # check for missing files
-            for expected_file in expected_files:
-                if expected_file not in actual_files:
-                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
-                    if (
-                        "_dtw_distance.npy" in expected_file
-                        or "_spatial_distance.npy" in expected_file
-                    ):
-                        missing_adj = True
-                    else:
-                        missing_main_files = True
+# ---------- 4. Download Kaggle datasets ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)

-    # invoke the download logic based on which type of file is missing
+# ---------- 5. Tidy up the data directory ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir,"data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src,dst = os.path.join(nested,item), os.path.join(data_dir,item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+
+    for kw,tgt in [("bay","PEMS-BAY"),("metr","METR-LA")]:
+        dst = os.path.join(data_dir,tgt); os.makedirs(dst,exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5",".pkl")):
+                shutil.move(os.path.join(data_dir,f), os.path.join(dst,f))
+
+    solar = os.path.join(data_dir,"solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir,"SolarEnergy"); os.makedirs(dst,exist_ok=True)
+        csv = os.path.join(solar,"solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst,"SolarEnergy.csv"))
+        shutil.rmtree(solar)
+
+# ---------- 6. Main workflow ----------
+def check_and_download_data():
+    cwd = os.getcwd()
+    data_dir = os.path.join(cwd,"data")
+    expected = load_structure_json()
+
+    missing_adj,_ = detect_data_integrity(data_dir, expected, check_adj=True)
     if missing_adj:
-        download_adj_data(current_dir)
-    if missing_main_files:
-        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
-        download_kaggle_data(current_dir, 'scchuy/pemsbay')
-        download_kaggle_data(current_dir, "annnnguyen/metr-la-dataset")
-        download_kaggle_data(current_dir, "wangshaoqi/solar-energy")
-
-    rearrange_dir()
+        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+    baq_folder = os.path.join(data_dir,"BeijingAirQuality")
+    if not os.path.isdir(baq_folder):
+        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+
+    _,missing_main = detect_data_integrity(data_dir, expected, check_adj=False)
+    if missing_main:
+        for ds in ["elmahy/pems-dataset","scchuy/pemsbay","annnnguyen/metr-la-dataset","wangshaoqi/solar-energy"]:
+            download_kaggle_data(cwd, ds)
+
+    rearrange_dir()
     return True
-
-
-def download_adj_data(current_dir, max_retries=3):
-    """
-    Download and extract adj.zip, showing a download progress bar.
-    If the download fails, retry at most max_retries times.
-    """
-    url = "http://code.zhang-heng.com/static/adj.zip"
-    retries = 0
-
-    while retries <= max_retries:
-        try:
-            print(f"Downloading adjacency-matrix files from {url}...")
-            response = requests.get(url, stream=True)
-
-            if response.status_code == 200:
-                total_size = int(response.headers.get("content-length", 0))
-                block_size = 1024  # 1KB
-                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="Download progress")
-
-                zip_file_path = os.path.join(current_dir, "adj.zip")
-                with open(zip_file_path, "wb") as f:
-                    for data in response.iter_content(block_size):
-                        f.write(data)
-                        t.update(len(data))
-                t.close()
-
-                # print("Download finished, file saved to:", zip_file_path)
-
-                if os.path.exists(zip_file_path):
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        zip_ref.extractall(current_dir)
-                    # print("Dataset extracted to:", current_dir)
-                    os.remove(zip_file_path)  # delete the zip file
-                else:
-                    print("Downloaded zip file not found, skipping extraction.")
-                break  # download succeeded, exit the loop
-            else:
-                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
-        except Exception as e:
-            print(f"Error while downloading or extracting the dataset: {e}")
-            print("If the link is invalid, please check the URL or try again later.")
-
-        retries += 1
-        if retries > max_retries:
-            raise Exception(
-                f"Download failed after reaching the maximum number of retries ({max_retries}). Please check the link or your network connection."
-            )
-
-
-def download_kaggle_data(current_dir, kaggle_path):
-    """
-    Download a KaggleHub dataset and move it directly into the data folder of
-    the current working directory, overwriting conflicting files if the target
-    folder already exists.
-    """
-    try:
-        print(f"Downloading the {kaggle_path} dataset...")
-        path = kagglehub.dataset_download(kaggle_path)
-        # print("Path to KaggleHub dataset files:", path)
-
-        if os.path.exists(path):
-            destination_path = os.path.join(current_dir, "data")
-            # use shutil.copytree to place the folder contents directly under data/, overwriting conflicting files
-            shutil.copytree(path, destination_path, dirs_exist_ok=True)
-    except Exception as e:
-        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
-
-
-
-def rearrange_dir():
-    """
-    Merge the files in data/data into the parent directory and remove the data/data directory.
-    """
-    data_dir = os.path.join(os.getcwd(), "data")
-    nested_data_dir = os.path.join(data_dir, "data")
-
-    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
-        for item in os.listdir(nested_data_dir):
-            source_path = os.path.join(nested_data_dir, item)
-            destination_path = os.path.join(data_dir, item)
-
-            if os.path.isdir(source_path):
-                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
-            else:
-                shutil.copy2(source_path, destination_path)
-
-        shutil.rmtree(nested_data_dir)
-
-    # move files whose names contain "bay" into the PEMS-BAY folder
-    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
-    os.makedirs(pems_bay_dir, exist_ok=True)
-
-    for item in os.listdir(data_dir):
-        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(pems_bay_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # metr-la
-    metrla_dir = os.path.join(data_dir, "METR-LA")
-    os.makedirs(metrla_dir, exist_ok=True)
-    for item in os.listdir(data_dir):
-        if "metr" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(metrla_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # solar-energy
-    solar_src = os.path.join(data_dir, "solar-energy")
-    solar_sub = os.path.join(solar_src, "solar_AL.txt")
-    solar_csv = os.path.join(solar_src, "solar_AL.csv")
-    solar_dst_dir = os.path.join(data_dir,"SolarEnergy")
-    solar_dst_csv = os.path.join(solar_dst_dir, "SolarEnergy.csv")
-    if os.path.isdir(solar_sub): shutil.rmtree(solar_sub)
-    if os.path.isdir(solar_src): os.rename(solar_src, solar_dst_dir)
-    if os.path.isfile(solar_csv.replace(solar_src, solar_dst_dir)):
-        os.rename(solar_csv.replace(solar_src, solar_dst_dir), solar_dst_csv)
-
-
-# main program
-if __name__ == "__main__":
+if __name__=="__main__":
     check_and_download_data()
diff --git a/utils/dataset.json b/utils/dataset.json
new file mode 100644
index 0000000..a3e6689
--- /dev/null
+++ b/utils/dataset.json
@@ -0,0 +1,39 @@
+{
+    "PEMS03": [
+        "PEMS03.csv",
+        "PEMS03.npz",
+        "PEMS03.txt",
+        "PEMS03_dtw_distance.npy",
+        "PEMS03_spatial_distance.npy"
+    ],
+    "PEMS04": [
+        "PEMS04.csv",
+        "PEMS04.npz",
+        "PEMS04_dtw_distance.npy",
+        "PEMS04_spatial_distance.npy"
+    ],
+    "PEMS07": [
+        "PEMS07.csv",
+        "PEMS07.npz",
+        "PEMS07_dtw_distance.npy",
+        "PEMS07_spatial_distance.npy"
+    ],
+    "PEMS08": [
+        "PEMS08.csv",
+        "PEMS08.npz",
+        "PEMS08_dtw_distance.npy",
+        "PEMS08_spatial_distance.npy"
+    ],
+    "PEMS-BAY": [
+        "adj_mx_bay.pkl",
+        "pems-bay-meta.h5",
+        "pems-bay.h5"
+    ],
+    "METR-LA": [
+        "METR-LA.h5"
+    ],
+    "SolarEnergy": [
+        "SolarEnergy.csv"
+    ],
+    "BeijingAirQuality": ["data.dat", "desc.json"]
+}