czzhangheng 2025-11-23 19:10:51 +08:00
commit 5cd81f4d4c
16 changed files with 650 additions and 298 deletions

42
.vscode/launch.json vendored

@ -4,7 +4,7 @@
// Visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
{
"name": "STID_PEMS-BAY",
"type": "debugpy",
"request": "launch",
@ -28,6 +28,14 @@
"console": "integratedTerminal",
"args": "--config ./config/REPST/PEMSD8.yaml"
},
{
"name": "REPST-BJTaxi-InFlow",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
},
{
"name": "REPST-PEMSBAY",
"type": "debugpy",
@ -36,6 +44,38 @@
"console": "integratedTerminal",
"args": "--config ./config/REPST/PEMS-BAY.yaml"
},
{
"name": "REPST-METR",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/METR-LA.yaml"
},
{
"name": "REPST-Solar",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/SolarEnergy.yaml"
},
{
"name": "BeijingAirQuality",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/BeijingAirQuality.yaml"
},
{
"name": "AirQuality",
"type": "debugpy",
"request": "launch",
"program": "run.py",
"console": "integratedTerminal",
"args": "--config ./config/REPST/AirQuality.yaml"
},
{
"name": "AEPSA-PEMSBAY",
"type": "debugpy",

12
.vscode/settings.json vendored

@ -1,5 +1,11 @@
{
"python-envs.defaultEnvManager": "ms-python.python:system",
"python-envs.defaultPackageManager": "ms-python.python:pip",
"python-envs.pythonProjects": []
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": [
{
"path": "data/SolarEnergy",
"envManager": "ms-python.python:conda",
"packageManager": "ms-python.python:conda"
}
]
}

61
config/REPST/AirQuality.yaml Executable file

@ -0,0 +1,61 @@
basic:
dataset: "AirQuality"
mode : "train"
device : "cuda:1"
model: "REPST"
seed: 2023
data:
add_day_in_week: false
add_time_in_day: false
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 35
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 6
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 6
output_dim: 3
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 3
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001
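
As a side note, here is a minimal sketch of reading such a REPST config with PyYAML; the loader below is an assumption for illustration rather than the repository's own parser, but the path and keys are taken from this file.

import yaml

# Minimal sketch (assumption): the repo may build its config object differently.
with open("./config/REPST/AirQuality.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

print(cfg["basic"]["dataset"])   # "AirQuality"
print(cfg["data"]["num_nodes"])  # 35
print(cfg["model"]["pred_len"])  # 24
print(cfg["train"]["lr_init"])   # 0.003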

60
config/REPST/BJTaxi-Inflow.yaml Executable file

@ -0,0 +1,60 @@
basic:
dataset: "BJTaxi-InFlow"
mode : "train"
device : "cuda:0"
model: "REPST"
seed: 2023
data:
add_day_in_week: false
add_time_in_day: false
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 1024
steps_per_day: 48
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 1
log_step: 100
plot: false
mae_thresh: None
mape_thresh: 0.001

61
config/REPST/BeijingAirQuality.yaml Executable file

@ -0,0 +1,61 @@
basic:
dataset: "BeijingAirQuality"
mode : "train"
device : "cuda:1"
model: "REPST"
seed: 2023
data:
add_day_in_week: false
add_time_in_day: false
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 7
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 3
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 3
output_dim: 3
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 3
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001

60
config/REPST/METR-LA.yaml Executable file

@ -0,0 +1,60 @@
basic:
dataset: "METR-LA"
mode : "train"
device : "cuda:1"
model: "REPST"
seed: 2023
data:
add_day_in_week: true
add_time_in_day: true
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 207
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 1
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001

60
config/REPST/SolarEnergy.yaml Executable file

@ -0,0 +1,60 @@
basic:
dataset: "SolarEnergy"
mode : "train"
device : "cuda:1"
model: "REPST"
seed: 2023
data:
add_day_in_week: false
add_time_in_day: false
column_wise: false
days_per_week: 7
default_graph: true
horizon: 24
lag: 24
normalizer: std
num_nodes: 137
steps_per_day: 288
test_ratio: 0.2
tod: false
val_ratio: 0.2
sample: 1
input_dim: 1
batch_size: 16
model:
pred_len: 24
seq_len: 24
patch_len: 6
stride: 7
dropout: 0.2
gpt_layers: 9
d_ff: 128
gpt_path: ./GPT-2
d_model: 64
n_heads: 1
input_dim: 1
word_num: 1000
train:
batch_size: 16
early_stop: true
early_stop_patience: 15
epochs: 100
grad_norm: false
loss_func: mae
lr_decay: true
lr_decay_rate: 0.3
lr_decay_step: "5,20,40,70"
lr_init: 0.003
max_grad_norm: 5
real_value: true
weight_decay: 0
debug: false
output_dim: 1
log_step: 1000
plot: false
mae_thresh: None
mape_thresh: 0.001


@ -7,57 +7,58 @@ def load_st_dataset(config):
# sample = config["data"]["sample"]
# output B, N, D
match dataset:
case "BeijingAirQuality":
data_path = os.path.join("./data/BeijingAirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 36000, 7, 3
data = data.reshape(L, N, C)
case "AirQuality":
data_path = os.path.join("./data/AirQuality/data.dat")
data = np.memmap(data_path, dtype=np.float32, mode='r')
L, N, C = 8701, 35, 6
data = data.reshape(L, N, C)
case "PEMS-BAY":
data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
with h5py.File(data_path, 'r') as f:
data = f['speed']['block0_values'][:]
case "METR-LA":
data_path = os.path.join("./data/METR-LA/METR-LA.h5")
with h5py.File(data_path, 'r') as f:
data = f['df']['block0_values'][:]
case "SolarEnergy":
data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
data = np.loadtxt(data_path, delimiter=",")
case "PEMSD3":
data_path = os.path.join("./data/PEMS03/PEMS03.npz")
data = np.load(data_path)["data"][
:, :, 0
]
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD4":
data_path = os.path.join("./data/PEMS04/PEMS04.npz")
data = np.load(data_path)["data"][
:, :, 0
]
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7":
data_path = os.path.join("./data/PEMS07/PEMS07.npz")
data = np.load(data_path)["data"][
:, :, 0
]
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD8":
data_path = os.path.join("./data/PEMS08/PEMS08.npz")
data = np.load(data_path)["data"][
:, :, 0
]
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(L)":
data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
data = np.load(data_path)["data"][
:, :, 0
]
data = np.load(data_path)["data"][:, :, 0]
case "PEMSD7(M)":
data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
data = np.genfromtxt(
data_path, delimiter=","
)
case "METR-LA":
data_path = os.path.join("./data/METR-LA/METR.h5")
with h5py.File(
data_path, "r"
) as f:
data = np.array(f["data"])
data = np.genfromtxt(data_path, delimiter=",")
case "BJ":
data_path = os.path.join("./data/BJ/BJ500.csv")
data = np.genfromtxt(
data_path, delimiter=",", skip_header=1
)
data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
case "Hainan":
data_path = os.path.join("./data/Hainan/Hainan.npz")
data = np.load(data_path)["data"][:, :, 0]
case "SD":
data_path = os.path.join("./data/SD/data.npz")
data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
case "BJTaxi-InFlow":
data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
case "BJTaxi-OutFlow":
data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
case _:
raise ValueError(f"Unsupported dataset: {dataset}")
@ -68,3 +69,16 @@ def load_st_dataset(config):
print("加载 %s 数据集中... " % dataset)
# return data[::sample]
return data
def read_BeijingTaxi():
files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
"TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
all_data = []
for file in files:
data_path = os.path.join(f"./data/BeijingTaxi/{file}")
data = np.load(data_path)
all_data.append(data)
all_data = np.concatenate(all_data, axis=0)
time_num = all_data.shape[0]
all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
return all_data
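A quick sanity check of the reshape above, assuming the raw TaxiBJ arrays are stored channel-first as (T, 2, 32, 32) with inflow/outflow on the channel axis, which is what the transpose implies:

import numpy as np

# Toy array standing in for one TaxiBJ .npy file (the raw shape is an assumption).
raw = np.random.rand(5, 2, 32, 32).astype(np.float32)    # T, C, H, W
flat = raw.transpose(0, 2, 3, 1).reshape(5, 32 * 32, 2)  # T, 1024 nodes, 2 flows
print(flat.shape)  # (5, 1024, 2); BJTaxi-InFlow keeps channel 0, BJTaxi-OutFlow keeps channel 1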


@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
return self.gumbel_softmax(logits, 1, self.k, self.hard)
def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
y_soft = F.gumbel_softmax(logits, tau, hard)
if hard:
# build a hard mask
_, indices = y_soft.topk(k, dim=0) # select the top-k entries


@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
return output
class TokenEmbedding(nn.Module):
def __init__(self, c_in, d_model, patch_num, input_dim):
def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
super(TokenEmbedding, self).__init__()
padding = 1
self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
kernel_size=3, padding=padding, padding_mode='circular', bias=False)
self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)
for m in self.modules():
if isinstance(m, nn.Conv1d):
@ -37,22 +37,20 @@ class TokenEmbedding(nn.Module):
class PatchEmbedding(nn.Module):
def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
super(PatchEmbedding, self).__init__()
# Patching
self.patch_len = patch_len
self.stride = stride
self.padding_patch_layer = ReplicationPad1d((0, stride))
self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
self.dropout = nn.Dropout(dropout)
def forward(self, x):
n_vars = x.shape[2]
x = self.padding_patch_layer(x)
x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
x_value_embed = self.value_embedding(x)
return self.dropout(x_value_embed), n_vars
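For reference, a minimal sketch of the pad-and-unfold patching performed in this forward pass, using the seq_len=24, patch_len=6, stride=7 values from the configs above; F.pad with replicate mode stands in for ReplicationPad1d here.

import torch
import torch.nn.functional as F

seq_len, patch_len, stride = 24, 6, 7        # values taken from the REPST configs
x = torch.randn(2, 3, seq_len)               # toy (batch, vars, time) tensor
x = F.pad(x, (0, stride), mode="replicate")  # right-pad by `stride`, like ReplicationPad1d((0, stride))
patches = x.unfold(dimension=-1, size=patch_len, step=stride)
print(patches.shape)  # torch.Size([2, 3, 4, 6]): 4 patches of length 6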
class ReprogrammingLayer(nn.Module):
@ -84,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
def reprogramming(self, target_embedding, source_embedding, value_embedding):
B, L, H, E = target_embedding.shape
scale = 1. / sqrt(E)
scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
A = self.dropout(torch.softmax(scale * scores, dim=-1))
reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
return reprogramming_embedding


@ -19,6 +19,7 @@ class repst(nn.Module):
self.gpt_layers = configs['gpt_layers']
self.d_ff = configs['d_ff']
self.gpt_path = configs['gpt_path']
self.output_dim = configs.get('output_dim', 1)
self.word_choice = GumbelSoftmax(configs['word_num'])
@ -31,7 +32,7 @@ class repst(nn.Module):
self.head_nf = self.d_ff * self.patch_nums
# word embedding
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
# GPT-2 initialization
self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
@ -41,12 +42,12 @@ class repst(nn.Module):
self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
self.vocab_size = self.word_embeddings.shape[0]
self.mapping_layer = nn.Linear(self.vocab_size, 1)
self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm)
self.out_mlp = nn.Sequential(
nn.Linear(self.d_llm, 128),
nn.ReLU(),
nn.Linear(128, self.pred_len)
nn.Linear(128, self.pred_len * self.output_dim)
)
for i, (name, param) in enumerate(self.gpts.named_parameters()):
@ -62,7 +63,7 @@ class repst(nn.Module):
torch.nn.init.zeros_(module.bias)
def forward(self, x):
x = x[..., :1]
x = x[..., :self.input_dim]
x_enc = rearrange(x, 'b t n c -> b n c t')
enc_out, n_vars = self.patch_embedding(x_enc)
self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@ -72,32 +73,11 @@ class repst(nn.Module):
enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
dec_out = self.out_mlp(enc_out)
outputs = dec_out.unsqueeze(dim=-1)
outputs = outputs.repeat(1, 1, 1, n_vars)
outputs = outputs.permute(0,2,1,3)
dec_out = self.out_mlp(enc_out) #[B, N, T*C]
B, N, _ = dec_out.shape
outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
outputs = outputs.permute(0, 2, 1, 3) # B, T, N, C
return outputs
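A shape check of the new multi-channel output path; the batch size, node count and output_dim below follow the AirQuality config and are otherwise arbitrary.

import torch

B, N, pred_len, output_dim = 16, 35, 24, 3          # AirQuality-style sizes (assumption)
dec_out = torch.randn(B, N, pred_len * output_dim)  # what out_mlp now emits per node
outputs = dec_out.view(B, N, pred_len, output_dim).permute(0, 2, 1, 3)
print(outputs.shape)  # torch.Size([16, 24, 35, 3]): B, T, N, C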
if __name__ == '__main__':
configs = {
'device': 'cuda:0',
'pred_len': 24,
'seq_len': 24,
'patch_len': 6,
'stride': 7,
'dropout': 0.2,
'gpt_layers': 9,
'd_ff': 128,
'gpt_path': './GPT-2',
'd_model': 64,
'n_heads': 1,
'input_dim': 1
}
model = repst(configs)
x = torch.randn(16, 24, 325, 1)
y = model(x)
print(y.shape)


@ -11,4 +11,5 @@ fastdtw
notebook
torchcde
einops
transformers
transformers
py7zr

2
run.py

@ -14,6 +14,8 @@ def main():
args = parse_args()
args = init.init_device(args)
init.init_seed(args["basic"]["seed"])
# Load model
model = init.init_model(args)
# Load dataset


@ -203,7 +203,7 @@ class Trainer:
self.stats.record_step_time(step_time, mode)
# accumulate loss and predictions
total_loss += d_loss.item()
total_loss += loss.item()
y_pred.append(d_output.detach().cpu())
y_true.append(d_label.detach().cpu())
@ -316,13 +316,9 @@ class Trainer:
def _log_model_params(self):
"""输出模型可训练参数数量"""
try:
total_params = sum(
p.numel() for p in self.model.parameters() if p.requires_grad
)
self.logger.info(f"Trainable params: {total_params}")
except Exception:
pass
total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
self.logger.info(f"Trainable params: {total_params}")
def _finalize_training(self, best_model, best_test_model):
self.model.load_state_dict(best_model)
@ -353,35 +349,26 @@ class Trainer:
for data, target in data_loader:
label = target[..., : args["output_dim"]]
output = model(data)
y_pred.append(output)
y_true.append(label)
y_pred.append(output.detach().cpu())
y_true.append(label.detach().cpu())
# concatenate predictions from all batches
if args["real_value"]:
y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
else:
y_pred = torch.cat(y_pred, dim=0)
y_true = torch.cat(y_true, dim=0)
d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
# compute and log metrics for each horizon step
for t in range(y_true.shape[1]):
for t in range(d_y_true.shape[1]):
mae, rmse, mape = all_metrics(
y_pred[:, t, ...],
y_true[:, t, ...],
d_y_pred[:, t, ...],
d_y_true[:, t, ...],
args["mae_thresh"],
args["mape_thresh"],
)
logger.info(
f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
)
logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
# compute and log the averaged metrics
mae, rmse, mape = all_metrics(
y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
)
logger.info(
f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
)
mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"])
logger.info( f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
@staticmethod
def _compute_sampling_threshold(global_step, k):


@ -1,204 +1,191 @@
import os
import requests
import zipfile
import shutil
import kagglehub # assumes kagglehub is an available library
import os, json, shutil, requests
from urllib.parse import urlsplit
from tqdm import tqdm
# dictionary describing the expected file structure
import kagglehub
import py7zr
def check_and_download_data():
"""
Check the integrity of the data folder and download whatever is missing, depending on which files are absent.
"""
current_working_dir = os.getcwd() # get the current working directory
data_dir = os.path.join(
current_working_dir, "data"
) # assume the data folder lives under the current working directory
expected_structure = {
"PEMS03": [
"PEMS03.csv",
"PEMS03.npz",
"PEMS03.txt",
"PEMS03_dtw_distance.npy",
"PEMS03_spatial_distance.npy",
],
"PEMS04": [
"PEMS04.csv",
"PEMS04.npz",
"PEMS04_dtw_distance.npy",
"PEMS04_spatial_distance.npy",
],
"PEMS07": [
"PEMS07.csv",
"PEMS07.npz",
"PEMS07_dtw_distance.npy",
"PEMS07_spatial_distance.npy",
],
"PEMS08": [
"PEMS08.csv",
"PEMS08.npz",
"PEMS08_dtw_distance.npy",
"PEMS08_spatial_distance.npy",
],
"PEMS-BAY": [
"adj_mx_bay.pkl",
"pems-bay-meta.h5",
"pems-bay.h5"
]
}
current_dir = os.getcwd() # current working directory
missing_adj = False
missing_main_files = False
# check whether the data folder exists
if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
# print(f"Directory {data_dir} does not exist.")
print("Downloading all required data files...")
missing_adj = True
missing_main_files = True
else:
# iterate over the expected file structure
for subfolder, expected_files in expected_structure.items():
subfolder_path = os.path.join(data_dir, subfolder)
# check whether the subfolder exists
if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
# print(f"Subfolder {subfolder} does not exist.")
missing_main_files = True
continue
# list the files actually present in the subfolder
actual_files = os.listdir(subfolder_path)
# check for missing files
for expected_file in expected_files:
if expected_file not in actual_files:
# print(f"File {expected_file} is missing from subfolder {subfolder}.")
if (
"_dtw_distance.npy" in expected_file
or "_spatial_distance.npy" in expected_file
):
missing_adj = True
else:
missing_main_files = True
# dispatch the downloads depending on which files are missing
if missing_adj:
download_adj_data(current_dir)
if missing_main_files:
download_kaggle_data(current_dir, 'elmahy/pems-dataset')
download_kaggle_data(current_dir, 'scchuy/pemsbay')
# ---------- 1. Check data integrity ----------
def detect_data_integrity(data_dir, expected):
missing_list = []
if not os.path.isdir(data_dir):
# if the data directory does not exist, every dataset is missing
missing_list.extend(expected.keys())
# mark the adjacency files as missing as well
missing_list.append("adj")
return missing_list
rearrange_dir()
# check the adjacency-related distance matrix files
has_missing_adj = False
for folder, files in expected.items():
folder_path = os.path.join(data_dir, folder)
if os.path.isdir(folder_path):
existing = set(os.listdir(folder_path))
for f in files:
if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
has_missing_adj = True
break
if has_missing_adj:
missing_list.append("adj")
# check the main dataset files
for folder, files in expected.items():
folder_path = os.path.join(data_dir, folder)
if not os.path.isdir(folder_path):
missing_list.append(folder)
continue
existing = set(os.listdir(folder_path))
has_missing_file = False
for f in files:
# skip distance matrix files; they were already checked above
if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
has_missing_file = True
if has_missing_file and folder not in missing_list:
missing_list.append(folder)
# print(f"Missing datasets: {missing_list}")
return missing_list
# ---------- 2. Download and extract 7z archives ----------
def download_and_extract(url, dst_dir, max_retries=3):
os.makedirs(dst_dir, exist_ok=True)
filename = os.path.basename(urlsplit(url).path) or "download.7z"
file_path = os.path.join(dst_dir, filename)
for attempt in range(1, max_retries+1):
try:
# download
with requests.get(url, stream=True, timeout=30) as r:
r.raise_for_status()
total = int(r.headers.get("content-length",0))
with open(file_path,"wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
for chunk in r.iter_content(8192):
f.write(chunk)
bar.update(len(chunk))
# extract the 7z archive
with py7zr.SevenZipFile(file_path, mode='r') as archive:
archive.extractall(path=dst_dir)
os.remove(file_path)
return
except Exception as e:
if attempt == max_retries: raise RuntimeError("Download or extraction failed")
print("Error, retrying...", e)
# ---------- 3. Download Kaggle data ----------
def download_kaggle_data(base_dir, dataset):
try:
print(f"Downloading Kaggle dataset: {dataset}")
path = kagglehub.dataset_download(dataset)
shutil.copytree(path, os.path.join(base_dir,"data"), dirs_exist_ok=True)
except Exception as e:
print("Kaggle download failed:", dataset, e)
# ---------- 4. Download GitHub data ----------
def download_github_data(file_path, save_dir):
if not os.path.exists(save_dir):
os.makedirs(save_dir)
raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
# raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
response = requests.head(raw_url, allow_redirects=True)
if response.status_code != 200:
print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
return
file_size = int(response.headers.get('Content-Length', 0))
response = requests.get(raw_url, stream=True, allow_redirects=True)
file_name = os.path.basename(file_path)
file_path_to_save = os.path.join(save_dir, file_name)
with open(file_path_to_save, 'wb') as f:
with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)
pbar.update(len(chunk))
# ---------- 5. Rearrange directories ----------
def rearrange_dir():
data_dir = os.path.join(os.getcwd(), "data")
nested = os.path.join(data_dir,"data")
if os.path.isdir(nested):
for item in os.listdir(nested):
src,dst = os.path.join(nested,item), os.path.join(data_dir,item)
if os.path.isdir(src):
shutil.copytree(src, dst, dirs_exist_ok=True) # merge into directories that already exist
else:
shutil.copy2(src, dst)
shutil.rmtree(nested)
for kw,tgt in [("bay","PEMS-BAY"),("metr","METR-LA")]:
dst = os.path.join(data_dir,tgt); os.makedirs(dst,exist_ok=True)
for f in os.listdir(data_dir):
if kw in f.lower() and f.endswith((".h5",".pkl")):
shutil.move(os.path.join(data_dir,f), os.path.join(dst,f))
solar = os.path.join(data_dir,"solar-energy")
if os.path.isdir(solar):
dst = os.path.join(data_dir,"SolarEnergy"); os.makedirs(dst,exist_ok=True)
csv = os.path.join(solar,"solar_AL.csv")
if os.path.isfile(csv): shutil.move(csv, os.path.join(dst,"SolarEnergy.csv"))
shutil.rmtree(solar)
# ---------- 6. Main flow ----------
def check_and_download_data():
# load the structure file and detect missing datasets
cwd = os.getcwd()
data_dir = os.path.join(cwd,"data")
with open("utils/dataset.json", "r", encoding="utf-8") as f:
file_tree = json.load(f)
missing_list = detect_data_integrity(data_dir, file_tree)
# print(f"Missing datasets: {missing_list}")
# check and download the adjacency data
if "adj" in missing_list:
download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
# remove adj from the missing list after downloading
missing_list.remove("adj")
# check BeijingAirQuality and AirQuality
if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
# refresh the missing list after downloading
missing_list = detect_data_integrity(data_dir, file_tree)
# check and download the TaxiBJ data
if "BeijingTaxi" in missing_list:
taxi_bj_folder = os.path.join(data_dir, "BeijingTaxi")
taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
for file in taxibj_files:
file_path = f"Datasets/TaxiBJ/{file}"
download_github_data(file_path, taxi_bj_folder)
# refresh the missing list after downloading
missing_list = detect_data_integrity(data_dir, file_tree)
# check and download the PEMS, PEMS-BAY, METR-LA and solar-energy data
kaggle_map = {
"PEMS03": "elmahy/pems-dataset",
"PEMS04": "elmahy/pems-dataset",
"PEMS07": "elmahy/pems-dataset",
"PEMS08": "elmahy/pems-dataset",
"PEMS-BAY": "scchuy/pemsbay",
"METR-LA": "annnnguyen/metr-la-dataset",
"SolarEnergy": "wangshaoqi/solar-energy"
}
# deduplicate the Kaggle dataset IDs first to avoid downloading the same dataset twice
downloaded_kaggle_datasets = set()
for dataset, kaggle_ds in kaggle_map.items():
if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
download_kaggle_data(cwd, kaggle_ds)
# record the dataset as downloaded
downloaded_kaggle_datasets.add(kaggle_ds)
# refresh the missing list after each download
missing_list = detect_data_integrity(data_dir, file_tree)
rearrange_dir()
return True
def download_adj_data(current_dir, max_retries=3):
"""
下载并解压 adj.zip 文件并显示下载进度条
如果下载失败最多重试 max_retries
"""
url = "http://code.zhang-heng.com/static/adj.zip"
retries = 0
while retries <= max_retries:
try:
print(f"正在从 {url} 下载邻接矩阵文件...")
response = requests.get(url, stream=True)
if response.status_code == 200:
total_size = int(response.headers.get("content-length", 0))
block_size = 1024 # 1KB
t = tqdm(total=total_size, unit="B", unit_scale=True, desc="下载进度")
zip_file_path = os.path.join(current_dir, "adj.zip")
with open(zip_file_path, "wb") as f:
for data in response.iter_content(block_size):
f.write(data)
t.update(len(data))
t.close()
# print("下载完成,文件已保存到:", zip_file_path)
if os.path.exists(zip_file_path):
with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
zip_ref.extractall(current_dir)
# print("数据集已解压到:", current_dir)
os.remove(zip_file_path) # 删除zip文件
else:
print("未找到下载的zip文件跳过解压。")
break # 下载成功,退出循环
else:
print(f"下载失败,状态码: {response.status_code}。请检查链接是否有效。")
except Exception as e:
print(f"下载或解压数据集时出错: {e}")
print("如果链接无效请检查URL的合法性或稍后重试。")
retries += 1
if retries > max_retries:
raise Exception(
f"下载失败,已达到最大重试次数({max_retries}次)。请检查链接或网络连接。"
)
def download_kaggle_data(current_dir, kaggle_path):
"""
下载 KaggleHub 数据集并将数据直接移动到当前工作目录的 data 文件夹
如果目标文件夹已存在会覆盖冲突的文件
"""
try:
print(f"正在下载 {kaggle_path} 数据集...")
path = kagglehub.dataset_download(kaggle_path)
# print("Path to KaggleHub dataset files:", path)
if os.path.exists(path):
destination_path = os.path.join(current_dir, "data")
# 使用 shutil.copytree 将文件夹内容直接放在 data 文件夹下,覆盖冲突的文件
shutil.copytree(path, destination_path, dirs_exist_ok=True)
except Exception as e:
print(f"下载或处理 KaggleHub 数据集时出错: {e}")
def rearrange_dir():
"""
data/data 中的文件合并到上级目录并删除 data/data 目录
"""
data_dir = os.path.join(os.getcwd(), "data")
nested_data_dir = os.path.join(data_dir, "data")
if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
for item in os.listdir(nested_data_dir):
source_path = os.path.join(nested_data_dir, item)
destination_path = os.path.join(data_dir, item)
if os.path.isdir(source_path):
shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
else:
shutil.copy2(source_path, destination_path)
shutil.rmtree(nested_data_dir)
# print(f"已合并 {nested_data_dir} 到 {data_dir},并删除嵌套目录。")
# 将带有 "bay" 的文件移动到 PEMS-BAY 文件夹
pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
os.makedirs(pems_bay_dir, exist_ok=True)
for item in os.listdir(data_dir):
if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
source_path = os.path.join(data_dir, item)
destination_path = os.path.join(pems_bay_dir, item)
shutil.move(source_path, destination_path)
# print(f"已将带有 'bay' 的文件移动到 {pems_bay_dir}。")
# 主程序
if __name__ == "__main__":
if __name__=="__main__":
check_and_download_data()
# rearrange_dir()

41
utils/dataset.json Normal file

@ -0,0 +1,41 @@
{
"PEMS03": [
"PEMS03.csv",
"PEMS03.npz",
"PEMS03.txt",
"PEMS03_dtw_distance.npy",
"PEMS03_spatial_distance.npy"
],
"PEMS04": [
"PEMS04.csv",
"PEMS04.npz",
"PEMS04_dtw_distance.npy",
"PEMS04_spatial_distance.npy"
],
"PEMS07": [
"PEMS07.csv",
"PEMS07.npz",
"PEMS07_dtw_distance.npy",
"PEMS07_spatial_distance.npy"
],
"PEMS08": [
"PEMS08.csv",
"PEMS08.npz",
"PEMS08_dtw_distance.npy",
"PEMS08_spatial_distance.npy"
],
"PEMS-BAY": [
"adj_mx_bay.pkl",
"pems-bay-meta.h5",
"pems-bay.h5"
],
"METR-LA": [
"METR-LA.h5"
],
"SolarEnergy": [
"SolarEnergy.csv"
],
"BeijingAirQuality": ["data.dat", "desc.json"],
"AirQuality": ["data.dat"],
"BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
}
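
For context, here is a minimal, self-contained sketch of a manifest-driven check in the spirit of detect_data_integrity above; it assumes it is run from the repository root so that utils/dataset.json and data/ resolve.

import json
import os

with open("utils/dataset.json", "r", encoding="utf-8") as f:
    file_tree = json.load(f)  # folder -> list of expected files

data_dir = os.path.join(os.getcwd(), "data")
missing = []
for folder, files in file_tree.items():
    folder_path = os.path.join(data_dir, folder)
    existing = set(os.listdir(folder_path)) if os.path.isdir(folder_path) else set()
    if any(name not in existing for name in files):
        missing.append(folder)
print("Missing datasets:", missing)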