diff --git a/.vscode/launch.json b/.vscode/launch.json
index a8f3c00..4f992f0 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,7 +4,7 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
-{
+        {
             "name": "STID_PEMS-BAY",
             "type": "debugpy",
             "request": "launch",
@@ -28,6 +28,14 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/PEMSD8.yaml"
         },
+        {
+            "name": "REPST-BJTaxi-InFlow",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
+        },
         {
             "name": "REPST-PEMSBAY",
             "type": "debugpy",
@@ -36,6 +44,38 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/PEMS-BAY.yaml"
         },
+        {
+            "name": "REPST-METR",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/METR-LA.yaml"
+        },
+        {
+            "name": "REPST-Solar",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/SolarEnergy.yaml"
+        },
+        {
+            "name": "BeijingAirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BeijingAirQuality.yaml"
+        },
+        {
+            "name": "AirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/AirQuality.yaml"
+        },
         {
             "name": "AEPSA-PEMSBAY",
             "type": "debugpy",
diff --git a/.vscode/settings.json b/.vscode/settings.json
index 2c67ab7..2201e74 100644
--- a/.vscode/settings.json
+++ b/.vscode/settings.json
@@ -1,5 +1,11 @@
 {
-    "python-envs.defaultEnvManager": "ms-python.python:system",
-    "python-envs.defaultPackageManager": "ms-python.python:pip",
-    "python-envs.pythonProjects": []
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
+    "python-envs.defaultPackageManager": "ms-python.python:conda",
+    "python-envs.pythonProjects": [
+        {
+            "path": "data/SolarEnergy",
+            "envManager": "ms-python.python:conda",
+            "packageManager": "ms-python.python:conda"
+        }
+    ]
 }
\ No newline at end of file
diff --git a/config/REPST/AirQuality.yaml b/config/REPST/AirQuality.yaml
new file mode 100755
index 0000000..b0683ae
--- /dev/null
+++ b/config/REPST/AirQuality.yaml
@@ -0,0 +1,62 @@
+basic:
+  dataset: "AirQuality"
+  mode: "train"
+  device: "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 35
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 6
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 6
+  output_dim: 3
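+  # output_dim < input_dim: only the first 3 of the 6 input channels are predicted (the trainer slices labels as target[..., :output_dim])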
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/config/REPST/BJTaxi-Inflow.yaml b/config/REPST/BJTaxi-Inflow.yaml
new file mode 100755
index 0000000..37577a2
--- /dev/null
+++ b/config/REPST/BJTaxi-Inflow.yaml
@@ -0,0 +1,60 @@
+basic:
+  dataset: "BJTaxi-InFlow"
+  mode: "train"
+  device: "cuda:0"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 1024
+  steps_per_day: 48
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 100
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/config/REPST/BeijingAirQuality(Deprecated).yaml b/config/REPST/BeijingAirQuality(Deprecated).yaml
new file mode 100755
index 0000000..595c971
--- /dev/null
+++ b/config/REPST/BeijingAirQuality(Deprecated).yaml
@@ -0,0 +1,61 @@
+basic:
+  dataset: "BeijingAirQuality"
+  mode: "train"
+  device: "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 7
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 3
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 3
+  output_dim: 3
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/config/REPST/METR-LA.yaml b/config/REPST/METR-LA.yaml
new file mode 100755
index 0000000..1e3e29d
--- /dev/null
+++ b/config/REPST/METR-LA.yaml
@@ -0,0 +1,61 @@
+basic:
+  dataset: "METR-LA"
+  mode: "train"
+  device: "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: true
+  add_time_in_day: true
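+  # these flags presumably append day-of-week/time-of-day channels; repst.forward keeps only x[..., :input_dim], i.e. the speed channel here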
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 207
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/config/REPST/SolarEnergy.yaml b/config/REPST/SolarEnergy.yaml
new file mode 100755
index 0000000..282c929
--- /dev/null
+++ b/config/REPST/SolarEnergy.yaml
@@ -0,0 +1,60 @@
+basic:
+  dataset: "SolarEnergy"
+  mode: "train"
+  device: "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 137
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
+
diff --git a/dataloader/data_selector.py b/dataloader/data_selector.py
index 19fe7f5..224b6fc 100644
--- a/dataloader/data_selector.py
+++ b/dataloader/data_selector.py
@@ -7,57 +7,59 @@ def load_st_dataset(config):
     # sample = config["data"]["sample"]
     # output B, N, D
     match dataset:
+        case "BeijingAirQuality":
+            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 36000, 7, 3
+            data = data.reshape(L, N, C)
+        case "AirQuality":
+            data_path = os.path.join("./data/AirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 8701, 35, 6
+            data = data.reshape(L, N, C)
         case "PEMS-BAY":
             data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
             with h5py.File(data_path, 'r') as f:
                 data = f['speed']['block0_values'][:]
+        case "METR-LA":
+            data_path = os.path.join("./data/METR-LA/METR-LA.h5")
+            with h5py.File(data_path, 'r') as f:
+                data = f['df']['block0_values'][:]
+        case "SolarEnergy":
+            data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
+            data = np.loadtxt(data_path, delimiter=",")
         case "PEMSD3":
             data_path = os.path.join("./data/PEMS03/PEMS03.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD4":
             data_path = os.path.join("./data/PEMS04/PEMS04.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7":
             data_path = os.path.join("./data/PEMS07/PEMS07.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD8":
             data_path = os.path.join("./data/PEMS08/PEMS08.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7(L)":
             data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7(M)":
             data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
-            data = np.genfromtxt(
-                data_path, delimiter=","
-            )
+            data = np.genfromtxt(data_path, delimiter=",")
-        case "METR-LA":
-            data_path = os.path.join("./data/METR-LA/METR.h5")
-            with h5py.File(
-                data_path, "r"
-            ) as f:
-                data = np.array(f["data"])
         case "BJ":
             data_path = os.path.join("./data/BJ/BJ500.csv")
-            data = np.genfromtxt(
-                data_path, delimiter=",", skip_header=1
-            )
+            data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
         case "Hainan":
             data_path = os.path.join("./data/Hainan/Hainan.npz")
             data = np.load(data_path)["data"][:, :, 0]
         case "SD":
             data_path = os.path.join("./data/SD/data.npz")
             data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
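+        # the two BJTaxi variants share one loader; each picks the in-flow or out-flow channel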
"BJTaxi-InFlow": + data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) + case "BJTaxi-OutFlow": + data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) case _: raise ValueError(f"Unsupported dataset: {dataset}") @@ -68,3 +69,16 @@ def load_st_dataset(config): print("加载 %s 数据集中... " % dataset) # return data[::sample] return data + +def read_BeijingTaxi(): + files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", + "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] + all_data = [] + for file in files: + data_path = os.path.join(f"./data/BeijingTaxi/{file}") + data = np.load(data_path) + all_data.append(data) + all_data = np.concatenate(all_data, axis=0) + time_num = all_data.shape[0] + all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) + return all_data \ No newline at end of file diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py index fb7e182..c112c7a 100644 --- a/model/REPST/normalizer.py +++ b/model/REPST/normalizer.py @@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module): return self.gumbel_softmax(logits, 1, self.k, self.hard) def gumbel_softmax(self, logits, tau=1, k=1000, hard=True): - y_soft = F.gumbel_softmax(logits, tau, hard) - if hard: # 生成硬掩码 _, indices = y_soft.topk(k, dim=0) # 选择Top-K diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py index 3806989..1ba7976 100644 --- a/model/REPST/reprogramming.py +++ b/model/REPST/reprogramming.py @@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module): return output class TokenEmbedding(nn.Module): - def __init__(self, c_in, d_model, patch_num, input_dim): + def __init__(self, c_in, d_model, patch_num, input_dim, output_dim): super(TokenEmbedding, self).__init__() padding = 1 self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model, kernel_size=3, padding=padding, padding_mode='circular', bias=False) - self.confusion_layer = nn.Linear(patch_num * input_dim, 1) + self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim) for m in self.modules(): if isinstance(m, nn.Conv1d): @@ -37,22 +37,20 @@ class TokenEmbedding(nn.Module): class PatchEmbedding(nn.Module): - def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim): + def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim): super(PatchEmbedding, self).__init__() # Patching self.patch_len = patch_len self.stride = stride self.padding_patch_layer = ReplicationPad1d((0, stride)) - self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim) + self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): - n_vars = x.shape[2] x = self.padding_patch_layer(x) x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride) x_value_embed = self.value_embedding(x) - return self.dropout(x_value_embed), n_vars class ReprogrammingLayer(nn.Module): @@ -84,13 +82,9 @@ class ReprogrammingLayer(nn.Module): def reprogramming(self, target_embedding, source_embedding, value_embedding): B, L, H, E = target_embedding.shape - scale = 1. 
+    all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
+    return all_data
\ No newline at end of file
diff --git a/model/REPST/normalizer.py b/model/REPST/normalizer.py
index fb7e182..c112c7a 100644
--- a/model/REPST/normalizer.py
+++ b/model/REPST/normalizer.py
@@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
         return self.gumbel_softmax(logits, 1, self.k, self.hard)
 
     def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
-
         y_soft = F.gumbel_softmax(logits, tau, hard)
-
         if hard:  # build the hard mask
             _, indices = y_soft.topk(k, dim=0)  # pick the top-k entries
diff --git a/model/REPST/reprogramming.py b/model/REPST/reprogramming.py
index 3806989..1ba7976 100644
--- a/model/REPST/reprogramming.py
+++ b/model/REPST/reprogramming.py
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -37,22 +37,20 @@ class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-
         n_vars = x.shape[2]
         x = self.padding_patch_layer(x)
         x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
         x_value_embed = self.value_embedding(x)
-
         return self.dropout(x_value_embed), n_vars
 
 
 class ReprogrammingLayer(nn.Module):
@@ -84,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
     def reprogramming(self, target_embedding, source_embedding, value_embedding):
         B, L, H, E = target_embedding.shape
-
         scale = 1. / sqrt(E)
-
         scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
-
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
-
         return reprogramming_embedding
\ No newline at end of file
diff --git a/model/REPST/repst.py b/model/REPST/repst.py
index 53a6046..6ceeb2a 100644
--- a/model/REPST/repst.py
+++ b/model/REPST/repst.py
@@ -19,6 +19,7 @@ class repst(nn.Module):
         self.gpt_layers = configs['gpt_layers']
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
+        self.output_dim = configs.get('output_dim', 1)
 
         self.word_choice = GumbelSoftmax(configs['word_num'])
 
@@ -31,7 +32,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums
 
         # word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
 
         # GPT-2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
@@ -41,12 +42,12 @@ class repst(nn.Module):
         self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
         self.vocab_size = self.word_embeddings.shape[0]
         self.mapping_layer = nn.Linear(self.vocab_size, 1)
-        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
+        self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm)
 
         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),
             nn.ReLU(),
-            nn.Linear(128, self.pred_len)
+            nn.Linear(128, self.pred_len * self.output_dim)
         )
 
         for i, (name, param) in enumerate(self.gpts.named_parameters()):
@@ -62,7 +63,7 @@ class repst(nn.Module):
             torch.nn.init.zeros_(module.bias)
 
     def forward(self, x):
-        x = x[..., :1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@@ -72,32 +73,11 @@ class repst(nn.Module):
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
         enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
 
-        dec_out = self.out_mlp(enc_out)
-        outputs = dec_out.unsqueeze(dim=-1)
-        outputs = outputs.repeat(1, 1, 1, n_vars)
-        outputs = outputs.permute(0,2,1,3)
+        dec_out = self.out_mlp(enc_out)  # [B, N, T*C]
+
+        B, N, _ = dec_out.shape
+        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
+        outputs = outputs.permute(0, 2, 1, 3)  # B, T, N, C
 
         return outputs
 
-if __name__ == '__main__':
-    configs = {
-        'device': 'cuda:0',
-        'pred_len': 24,
-        'seq_len': 24,
-        'patch_len': 6,
-        'stride': 7,
-        'dropout': 0.2,
-        'gpt_layers': 9,
-        'd_ff': 128,
-        'gpt_path': './GPT-2',
-        'd_model': 64,
-        'n_heads': 1,
-        'input_dim': 1
-    }
-    model = repst(configs)
-    x = torch.randn(16, 24, 325, 1)
-    y = model(x)
-
-    print(y.shape)
-
-
diff --git a/requirements.txt b/requirements.txt
index d23694d..6199d59 100755
--- a/requirements.txt
+++ b/requirements.txt
@@ -11,4 +11,5 @@ fastdtw
 notebook
 torchcde
 einops
-transformers
\ No newline at end of file
+transformers
+py7zr
\ No newline at end of file
diff --git a/run.py b/run.py
index 741f9d7..95867f4 100755
--- a/run.py
+++ b/run.py
@@ -14,6 +14,8 @@ def main():
     args = parse_args()
     args = init.init_device(args)
     init.init_seed(args["basic"]["seed"])
+
+    # Load model
     model = init.init_model(args)
 
     # Load dataset
diff --git a/trainer/Trainer.py b/trainer/Trainer.py
index 626e9e9..508230e 100755
--- a/trainer/Trainer.py
+++ b/trainer/Trainer.py
@@ -203,7 +203,7 @@ class Trainer:
             self.stats.record_step_time(step_time, mode)
 
             # accumulate loss and predictions
-            total_loss += d_loss.item()
+            total_loss += loss.item()
             y_pred.append(d_output.detach().cpu())
             y_true.append(d_label.detach().cpu())
 
@@ -316,13 +316,9 @@ class Trainer:
 
     def _log_model_params(self):
        """Print the number of trainable model parameters."""
-        try:
-            total_params = sum(
-                p.numel() for p in self.model.parameters() if p.requires_grad
-            )
-            self.logger.info(f"Trainable params: {total_params}")
-        except Exception:
-            pass
+        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        self.logger.info(f"Trainable params: {total_params}")
+
 
     def _finalize_training(self, best_model, best_test_model):
         self.model.load_state_dict(best_model)
@@ -353,35 +349,27 @@ class Trainer:
         for data, target in data_loader:
             label = target[..., : args["output_dim"]]
             output = model(data)
-            y_pred.append(output)
-            y_true.append(label)
+            y_pred.append(output.detach().cpu())
+            y_true.append(label.detach().cpu())
 
-        # merge the predictions from all batches
-        if args["real_value"]:
-            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
-        else:
-            y_pred = torch.cat(y_pred, dim=0)
-        y_true = torch.cat(y_true, dim=0)
+
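+        # inverse-transform both predictions and labels so metrics are reported on the original scale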
+        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
+        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
 
         # compute and log the metrics for each horizon step
-        for t in range(y_true.shape[1]):
+        for t in range(d_y_true.shape[1]):
             mae, rmse, mape = all_metrics(
-                y_pred[:, t, ...],
-                y_true[:, t, ...],
+                d_y_pred[:, t, ...],
+                d_y_true[:, t, ...],
                 args["mae_thresh"],
                 args["mape_thresh"],
             )
-            logger.info(
-                f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
-            )
+            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
 
         # compute and log the averaged metrics
-        mae, rmse, mape = all_metrics(
-            y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
-        )
-        logger.info(
-            f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
-        )
+        mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"])
+        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
 
     @staticmethod
     def _compute_sampling_threshold(global_step, k):
diff --git a/utils/Download_data.py b/utils/Download_data.py
index fcd21f1..539a804 100755
--- a/utils/Download_data.py
+++ b/utils/Download_data.py
@@ -1,204 +1,191 @@
-import os
-import requests
-import zipfile
-import shutil
-import kagglehub  # assuming kagglehub is an available library
+import os, json, shutil, requests
+from urllib.parse import urlsplit
 from tqdm import tqdm
-
-# dictionary describing the expected file layout
+import kagglehub
+import py7zr
 
-def check_and_download_data():
-    """
-    Check the integrity of the data folder and download whatever is missing.
-    """
-    current_working_dir = os.getcwd()  # current working directory
-    data_dir = os.path.join(
-        current_working_dir, "data"
-    )  # assume the data folder lives in the current working directory
-
-    expected_structure = {
-        "PEMS03": [
-            "PEMS03.csv",
-            "PEMS03.npz",
-            "PEMS03.txt",
-            "PEMS03_dtw_distance.npy",
-            "PEMS03_spatial_distance.npy",
-        ],
-        "PEMS04": [
-            "PEMS04.csv",
-            "PEMS04.npz",
-            "PEMS04_dtw_distance.npy",
-            "PEMS04_spatial_distance.npy",
-        ],
-        "PEMS07": [
-            "PEMS07.csv",
-            "PEMS07.npz",
-            "PEMS07_dtw_distance.npy",
-            "PEMS07_spatial_distance.npy",
-        ],
-        "PEMS08": [
-            "PEMS08.csv",
-            "PEMS08.npz",
-            "PEMS08_dtw_distance.npy",
-            "PEMS08_spatial_distance.npy",
-        ],
-        "PEMS-BAY": [
-            "adj_mx_bay.pkl",
-            "pems-bay-meta.h5",
-            "pems-bay.h5"
-        ]
-    }
-
-    current_dir = os.getcwd()  # current working directory
-    missing_adj = False
-    missing_main_files = False
-
-    # check whether the data folder exists
-    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
-        # print(f"Directory {data_dir} does not exist.")
-        print("Downloading all required data files...")
-        missing_adj = True
-        missing_main_files = True
-    else:
-        # walk the expected file structure
-        for subfolder, expected_files in expected_structure.items():
-            subfolder_path = os.path.join(data_dir, subfolder)
-
-            # check whether the subfolder exists
-            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
-                # print(f"Subfolder {subfolder} does not exist.")
-                missing_main_files = True
-                continue
-
-            # list the files actually present in the subfolder
-            actual_files = os.listdir(subfolder_path)
-
-            # check for missing files
-            for expected_file in expected_files:
-                if expected_file not in actual_files:
-                    # print(f"Subfolder {subfolder} is missing file {expected_file}.")
-                    if (
-                        "_dtw_distance.npy" in expected_file
-                        or "_spatial_distance.npy" in expected_file
-                    ):
-                        missing_adj = True
-                    else:
-                        missing_main_files = True
-
-    # invoke the download logic according to what is missing
-    if missing_adj:
-        download_adj_data(current_dir)
-    if missing_main_files:
-        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
-        download_kaggle_data(current_dir, 'scchuy/pemsbay')
+# ---------- 1. Integrity check ----------
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # if the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # mark the adjacency files as missing too
+        missing_list.append("adj")
+        return missing_list
 
-    rearrange_dir()
+    # check adjacency-related files (the distance matrices)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # check the main dataset files
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if not os.path.isdir(folder_path):
+            missing_list.append(folder)
+            continue
+
+        existing = set(os.listdir(folder_path))
+        has_missing_file = False
+
+        for f in files:
+            # skip the distance-matrix files, they were checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list
+
+# ---------- 2. Download and extract a 7z archive ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length", 0))
+                with open(file_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt == max_retries: raise RuntimeError("Download or extraction failed")
+            print("Error, retrying...", e)
+
+# ---------- 3. Download Kaggle data ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        print(f"Downloading kaggle dataset: {dataset}")
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)
+
+# ---------- 4. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
+# ---------- 5. Tidy up the directory layout ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir, "data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+
+    for kw, tgt in [("bay", "PEMS-BAY"), ("metr", "METR-LA")]:
+        dst = os.path.join(data_dir, tgt); os.makedirs(dst, exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5", ".pkl")):
+                shutil.move(os.path.join(data_dir, f), os.path.join(dst, f))
+
+    solar = os.path.join(data_dir, "solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir, "SolarEnergy"); os.makedirs(dst, exist_ok=True)
+        csv = os.path.join(solar, "solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst, "SolarEnergy.csv"))
+        shutil.rmtree(solar)
+
+# ---------- 6. Main flow ----------
+def check_and_download_data():
+    # load the structure file and detect missing datasets
+    cwd = os.getcwd()
+    data_dir = os.path.join(cwd, "data")
+    with open("utils/dataset.json", "r", encoding="utf-8") as f:
+        file_tree = json.load(f)
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    # print(f"Missing datasets: {missing_list}")
+
+    # check for and download the adjacency data
+    if "adj" in missing_list:
+        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # remove adj from the missing list once downloaded
+        missing_list.remove("adj")
+
+    # check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
+        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # check for and download the TaxiBJ data
+    if "TaxiBJ" in missing_list:
+        taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_folder)
+        # refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # check for and download the pems, bay, metr-la and solar-energy data
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+
+    # de-duplicate the kaggle sources first so the same dataset is not downloaded twice
+    downloaded_kaggle_datasets = set()
+
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # record the dataset as downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)
+
+    rearrange_dir()
     return True
-
-def download_adj_data(current_dir, max_retries=3):
-    """
-    Download and extract the adj.zip file, showing a download progress bar.
-    On failure, retry at most max_retries times.
-    """
-    url = "http://code.zhang-heng.com/static/adj.zip"
-    retries = 0
-
-    while retries <= max_retries:
-        try:
-            print(f"Downloading the adjacency matrix files from {url}...")
-            response = requests.get(url, stream=True)
-
-            if response.status_code == 200:
-                total_size = int(response.headers.get("content-length", 0))
-                block_size = 1024  # 1KB
-                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="Download progress")
-
-                zip_file_path = os.path.join(current_dir, "adj.zip")
-                with open(zip_file_path, "wb") as f:
-                    for data in response.iter_content(block_size):
-                        f.write(data)
-                        t.update(len(data))
-                t.close()
-
-                # print("Download finished, file saved to:", zip_file_path)
-
-                if os.path.exists(zip_file_path):
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        zip_ref.extractall(current_dir)
-                    # print("Dataset extracted to:", current_dir)
-                    os.remove(zip_file_path)  # delete the zip file
-                else:
-                    print("Downloaded zip file not found; skipping extraction.")
-                break  # download succeeded, leave the retry loop
-            else:
-                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
-        except Exception as e:
-            print(f"Error while downloading or extracting the dataset: {e}")
-            print("If the link is invalid, verify the URL or retry later.")
-
-        retries += 1
-        if retries > max_retries:
-            raise Exception(
-                f"Download failed after reaching the maximum number of retries ({max_retries}). Check the link or the network connection."
-            )
-
-
-def download_kaggle_data(current_dir, kaggle_path):
-    """
-    Download a KaggleHub dataset and move it straight into the data folder of the current working directory.
-    Conflicting files in an existing target folder are overwritten.
-    """
-    try:
-        print(f"Downloading the {kaggle_path} dataset...")
-        path = kagglehub.dataset_download(kaggle_path)
-        # print("Path to KaggleHub dataset files:", path)
-
-        if os.path.exists(path):
-            destination_path = os.path.join(current_dir, "data")
-            # use shutil.copytree to place the folder contents directly under data, overwriting conflicting files
-            shutil.copytree(path, destination_path, dirs_exist_ok=True)
-    except Exception as e:
-        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
-
-
-
-def rearrange_dir():
-    """
-    Merge the files under data/data into the parent directory and delete the data/data directory.
-    """
-    data_dir = os.path.join(os.getcwd(), "data")
-    nested_data_dir = os.path.join(data_dir, "data")
-
-    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
-        for item in os.listdir(nested_data_dir):
-            source_path = os.path.join(nested_data_dir, item)
-            destination_path = os.path.join(data_dir, item)
-
-            if os.path.isdir(source_path):
-                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
-            else:
-                shutil.copy2(source_path, destination_path)
-
-        shutil.rmtree(nested_data_dir)
-        # print(f"Merged {nested_data_dir} into {data_dir} and removed the nested directory.")
-
-    # move files whose names contain "bay" into the PEMS-BAY folder
-    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
-    os.makedirs(pems_bay_dir, exist_ok=True)
-
-    for item in os.listdir(data_dir):
-        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(pems_bay_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # print(f"Moved the files containing 'bay' into {pems_bay_dir}.")
-
-
-# entry point
-if __name__ == "__main__":
+if __name__=="__main__":
     check_and_download_data()
-    # rearrange_dir()
diff --git a/utils/dataset.json b/utils/dataset.json
new file mode 100644
index 0000000..55f46c3
--- /dev/null
+++ b/utils/dataset.json
@@ -0,0 +1,41 @@
+{
+    "PEMS03": [
+        "PEMS03.csv",
+        "PEMS03.npz",
+        "PEMS03.txt",
+        "PEMS03_dtw_distance.npy",
+        "PEMS03_spatial_distance.npy"
+    ],
+    "PEMS04": [
+        "PEMS04.csv",
+        "PEMS04.npz",
+        "PEMS04_dtw_distance.npy",
+        "PEMS04_spatial_distance.npy"
+    ],
+    "PEMS07": [
+        "PEMS07.csv",
+        "PEMS07.npz",
+        "PEMS07_dtw_distance.npy",
+        "PEMS07_spatial_distance.npy"
+    ],
+    "PEMS08": [
+        "PEMS08.csv",
+        "PEMS08.npz",
+        "PEMS08_dtw_distance.npy",
+        "PEMS08_spatial_distance.npy"
+    ],
+    "PEMS-BAY": [
+        "adj_mx_bay.pkl",
+        "pems-bay-meta.h5",
+        "pems-bay.h5"
+    ],
+    "METR-LA": [
+        "METR-LA.h5"
+    ],
+    "SolarEnergy": [
+        "SolarEnergy.csv"
+    ],
+    "BeijingAirQuality": ["data.dat", "desc.json"],
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
+}