Merge branch 'REPST' of https://github.zhang-heng.com/czzhangheng/TrafficWheel into REPST

commit 5cd81f4d4c
@@ -4,7 +4,7 @@
     // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
     "version": "0.2.0",
     "configurations": [
         {
             "name": "STID_PEMS-BAY",
             "type": "debugpy",
             "request": "launch",
@@ -28,6 +28,14 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/PEMSD8.yaml"
         },
+        {
+            "name": "REPST-BJTaxi-InFlow",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BJTaxi-Inflow.yaml"
+        },
         {
             "name": "REPST-PEMSBAY",
             "type": "debugpy",
@@ -36,6 +44,38 @@
             "console": "integratedTerminal",
             "args": "--config ./config/REPST/PEMS-BAY.yaml"
         },
+        {
+            "name": "REPST-METR",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/METR-LA.yaml"
+        },
+        {
+            "name": "REPST-Solar",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/SolarEnergy.yaml"
+        },
+        {
+            "name": "BeijingAirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/BeijingAirQuality.yaml"
+        },
+        {
+            "name": "AirQuality",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "run.py",
+            "console": "integratedTerminal",
+            "args": "--config ./config/REPST/AirQuality.yaml"
+        },
         {
             "name": "AEPSA-PEMSBAY",
             "type": "debugpy",
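Each new debug entry above simply hands one of the REPST YAML files to run.py via --config. A minimal sketch of how such an argument could be parsed and loaded is shown below; it assumes an argparse-based parse_args and PyYAML, neither of which is shown in this diff, so the repository's actual implementation may differ.

# Hedged sketch: turning a "--config ./config/REPST/METR-LA.yaml" argument into a dict.
import argparse
import yaml  # assumes PyYAML is installed

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True,
                        help="path to a YAML file such as ./config/REPST/METR-LA.yaml")
    cli = parser.parse_args()
    with open(cli.config, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)  # returns the basic/data/model/train sections as nested dicts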
@@ -1,5 +1,11 @@
 {
-    "python-envs.defaultEnvManager": "ms-python.python:system",
-    "python-envs.defaultPackageManager": "ms-python.python:pip",
-    "python-envs.pythonProjects": []
+    "python-envs.defaultEnvManager": "ms-python.python:conda",
+    "python-envs.defaultPackageManager": "ms-python.python:conda",
+    "python-envs.pythonProjects": [
+        {
+            "path": "data/SolarEnergy",
+            "envManager": "ms-python.python:conda",
+            "packageManager": "ms-python.python:conda"
+        }
+    ]
 }
@@ -0,0 +1,61 @@
+basic:
+  dataset: "AirQuality"
+  mode : "train"
+  device : "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 35
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 6
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 6
+  output_dim: 3
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
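The model section's seq_len, patch_len and stride determine how many patches PatchEmbedding produces (the layer is padded by stride and then unfolded, as in the embedding changes later in this commit). A hedged arithmetic sketch of the implied patch count, assuming that padding-then-unfold scheme:

# seq_len=24, patch_len=6, stride=7 as in the config above.
seq_len, patch_len, stride = 24, 6, 7
padded = seq_len + stride                       # 31 after right-padding by `stride`
patch_num = (padded - patch_len) // stride + 1  # unfold(size=patch_len, step=stride) -> 4
print(patch_num)  # 4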
@@ -0,0 +1,60 @@
+basic:
+  dataset: "BJTaxi-InFlow"
+  mode : "train"
+  device : "cuda:0"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 1024
+  steps_per_day: 48
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 100
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
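num_nodes: 1024 and steps_per_day: 48 match the TaxiBJ grid data added elsewhere in this commit; a hedged consistency check, assuming the usual 32x32 TaxiBJ grid and 30-minute bins:

grid_h = grid_w = 32
assert grid_h * grid_w == 1024   # num_nodes = flattened 32x32 grid cells
assert 24 * 60 // 30 == 48       # steps_per_day at 30-minute resolution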
@@ -0,0 +1,61 @@
+basic:
+  dataset: "BeijingAirQuality"
+  mode : "train"
+  device : "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 7
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 3
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 3
+  output_dim: 3
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 3
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
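num_nodes: 7 and input_dim: 3 must agree with the (L, N, C) = (36000, 7, 3) reshape the data loader below applies to the BeijingAirQuality memmap; a hedged sanity check of the implied element count:

# The raw data.dat file is expected to hold L * N * C float32 values.
L, N, C = 36000, 7, 3
expected_elements = L * N * C   # 756,000 values
print(expected_elements)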
@@ -0,0 +1,60 @@
+basic:
+  dataset: "METR-LA"
+  mode : "train"
+  device : "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: true
+  add_time_in_day: true
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 207
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
@@ -0,0 +1,60 @@
+basic:
+  dataset: "SolarEnergy"
+  mode : "train"
+  device : "cuda:1"
+  model: "REPST"
+  seed: 2023
+
+data:
+  add_day_in_week: false
+  add_time_in_day: false
+  column_wise: false
+  days_per_week: 7
+  default_graph: true
+  horizon: 24
+  lag: 24
+  normalizer: std
+  num_nodes: 137
+  steps_per_day: 288
+  test_ratio: 0.2
+  tod: false
+  val_ratio: 0.2
+  sample: 1
+  input_dim: 1
+  batch_size: 16
+
+model:
+  pred_len: 24
+  seq_len: 24
+  patch_len: 6
+  stride: 7
+  dropout: 0.2
+  gpt_layers: 9
+  d_ff: 128
+  gpt_path: ./GPT-2
+  d_model: 64
+  n_heads: 1
+  input_dim: 1
+  word_num: 1000
+
+train:
+  batch_size: 16
+  early_stop: true
+  early_stop_patience: 15
+  epochs: 100
+  grad_norm: false
+  loss_func: mae
+  lr_decay: true
+  lr_decay_rate: 0.3
+  lr_decay_step: "5,20,40,70"
+  lr_init: 0.003
+  max_grad_norm: 5
+  real_value: true
+  weight_decay: 0
+  debug: false
+  output_dim: 1
+  log_step: 1000
+  plot: false
+  mae_thresh: None
+  mape_thresh: 0.001
@@ -7,57 +7,58 @@ def load_st_dataset(config):
     # sample = config["data"]["sample"]
     # output B, N, D
     match dataset:
+        case "BeijingAirQuality":
+            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 36000, 7, 3
+            data = data.reshape(L, N, C)
+        case "AirQuality":
+            data_path = os.path.join("./data/AirQuality/data.dat")
+            data = np.memmap(data_path, dtype=np.float32, mode='r')
+            L, N, C = 8701, 35, 6
+            data = data.reshape(L, N, C)
         case "PEMS-BAY":
             data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
             with h5py.File(data_path, 'r') as f:
                 data = f['speed']['block0_values'][:]
+        case "METR-LA":
+            data_path = os.path.join("./data/METR-LA/METR-LA.h5")
+            with h5py.File(data_path, 'r') as f:
+                data = f['df']['block0_values'][:]
+        case "SolarEnergy":
+            data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
+            data = np.loadtxt(data_path, delimiter=",")
         case "PEMSD3":
             data_path = os.path.join("./data/PEMS03/PEMS03.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD4":
             data_path = os.path.join("./data/PEMS04/PEMS04.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7":
             data_path = os.path.join("./data/PEMS07/PEMS07.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD8":
             data_path = os.path.join("./data/PEMS08/PEMS08.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7(L)":
             data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
-            data = np.load(data_path)["data"][
-                :, :, 0
-            ]
+            data = np.load(data_path)["data"][:, :, 0]
         case "PEMSD7(M)":
             data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
-            data = np.genfromtxt(
-                data_path, delimiter=","
-            )
-        case "METR-LA":
-            data_path = os.path.join("./data/METR-LA/METR.h5")
-            with h5py.File(
-                data_path, "r"
-            ) as f:
-                data = np.array(f["data"])
+            data = np.genfromtxt(data_path, delimiter=",")
         case "BJ":
             data_path = os.path.join("./data/BJ/BJ500.csv")
-            data = np.genfromtxt(
-                data_path, delimiter=",", skip_header=1
-            )
+            data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
         case "Hainan":
             data_path = os.path.join("./data/Hainan/Hainan.npz")
             data = np.load(data_path)["data"][:, :, 0]
         case "SD":
             data_path = os.path.join("./data/SD/data.npz")
             data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
+        case "BJTaxi-InFlow":
+            data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
+        case "BJTaxi-OutFlow":
+            data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
         case _:
             raise ValueError(f"Unsupported dataset: {dataset}")
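A hedged usage sketch for the loader above. It assumes load_st_dataset reads the dataset name from config["basic"]["dataset"], which matches the YAML files added in this commit but is not visible in this hunk, so the real key path may differ.

# Minimal call, assuming the function signature and config layout described above.
config = {"basic": {"dataset": "AirQuality"}, "data": {"sample": 1}}
data = load_st_dataset(config)
print(data.shape)  # expected (8701, 35, 6) for the AirQuality memmap case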
@@ -68,3 +69,16 @@ def load_st_dataset(config):
     print("Loading the %s dataset... " % dataset)
     # return data[::sample]
     return data
+
+
+def read_BeijingTaxi():
+    files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
+             "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
+    all_data = []
+    for file in files:
+        data_path = os.path.join(f"./data/BeijingTaxi/{file}")
+        data = np.load(data_path)
+        all_data.append(data)
+    all_data = np.concatenate(all_data, axis=0)
+    time_num = all_data.shape[0]
+    all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2)
+    return all_data
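A hedged shape sketch of the reshape performed by read_BeijingTaxi above, assuming the raw TaxiBJ arrays are stored as (time, 2 flow channels, 32, 32) grids:

import numpy as np

raw = np.zeros((100, 2, 32, 32), dtype=np.float32)       # stand-in for one TaxiBJ file
flat = raw.transpose(0, 2, 3, 1).reshape(100, 32 * 32, 2)  # (time, 1024 cells, 2 channels)
print(flat.shape)          # (100, 1024, 2)
inflow = flat[:, :, 0:1]   # what the "BJTaxi-InFlow" case slices out
print(inflow.shape)        # (100, 1024, 1)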
@@ -13,9 +13,7 @@ class GumbelSoftmax(nn.Module):
         return self.gumbel_softmax(logits, 1, self.k, self.hard)
 
     def gumbel_softmax(self, logits, tau=1, k=1000, hard=True):
-
         y_soft = F.gumbel_softmax(logits, tau, hard)
-
         if hard:
             # Build the hard mask
             _, indices = y_soft.topk(k, dim=0)  # select the top-k entries
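A hedged sketch of the top-k hard-mask idea used above: keep only the k highest-scoring entries of the Gumbel-softmax output along dim 0. The tensor sizes are illustrative, not the model's actual ones.

import torch
import torch.nn.functional as F

logits = torch.randn(1000, 8)                     # e.g. vocabulary-sized scores
y_soft = F.gumbel_softmax(logits, tau=1.0, hard=True)
k = 10
_, indices = y_soft.topk(k, dim=0)                # indices of the top-k entries per column
mask = torch.zeros_like(y_soft).scatter_(0, indices, 1.0)
print(mask.sum(dim=0))                            # k ones in every column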
@@ -15,13 +15,13 @@ class ReplicationPad1d(nn.Module):
         return output
 
 class TokenEmbedding(nn.Module):
-    def __init__(self, c_in, d_model, patch_num, input_dim):
+    def __init__(self, c_in, d_model, patch_num, input_dim, output_dim):
         super(TokenEmbedding, self).__init__()
         padding = 1
         self.tokenConv = nn.Conv1d(in_channels=c_in, out_channels=d_model,
                                    kernel_size=3, padding=padding, padding_mode='circular', bias=False)
 
-        self.confusion_layer = nn.Linear(patch_num * input_dim, 1)
+        self.confusion_layer = nn.Linear(patch_num * input_dim, output_dim)
 
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
@@ -37,22 +37,20 @@ class TokenEmbedding(nn.Module):
 
 
 class PatchEmbedding(nn.Module):
-    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim):
+    def __init__(self, d_model, patch_len, stride, dropout, patch_num, input_dim, output_dim):
         super(PatchEmbedding, self).__init__()
         # Patching
         self.patch_len = patch_len
         self.stride = stride
         self.padding_patch_layer = ReplicationPad1d((0, stride))
-        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim)
+        self.value_embedding = TokenEmbedding(patch_len, d_model, patch_num, input_dim, output_dim)
         self.dropout = nn.Dropout(dropout)
 
     def forward(self, x):
-
         n_vars = x.shape[2]
         x = self.padding_patch_layer(x)
         x = x.unfold(dimension=-1, size=self.patch_len, step=self.stride)
         x_value_embed = self.value_embedding(x)
-
         return self.dropout(x_value_embed), n_vars
 
 class ReprogrammingLayer(nn.Module):
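A hedged shape sketch of the patching path in forward above: the (B, N, C, T) input is right-padded by stride on the time axis and unfolded into patches of length patch_len. The padding is emulated here by repeating the last time step, standing in for the repository's ReplicationPad1d.

import torch

B, N, C, T = 2, 207, 1, 24
patch_len, stride = 6, 7
x = torch.randn(B, N, C, T)
pad = x[..., -1:].repeat(1, 1, 1, stride)        # replicate the last step `stride` times
x = torch.cat([x, pad], dim=-1)                  # (B, N, C, T + stride)
patches = x.unfold(dimension=-1, size=patch_len, step=stride)
print(patches.shape)                             # (2, 207, 1, 4, 6): 4 patches per node and channel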
@@ -84,13 +82,9 @@ class ReprogrammingLayer(nn.Module):
 
     def reprogramming(self, target_embedding, source_embedding, value_embedding):
         B, L, H, E = target_embedding.shape
-
         scale = 1. / sqrt(E)
-
         scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
-
         A = self.dropout(torch.softmax(scale * scores, dim=-1))
         reprogramming_embedding = torch.einsum("bhls,she->blhe", A, value_embedding)
-
         return reprogramming_embedding
 
@@ -19,6 +19,7 @@ class repst(nn.Module):
         self.gpt_layers = configs['gpt_layers']
         self.d_ff = configs['d_ff']
         self.gpt_path = configs['gpt_path']
+        self.output_dim = configs.get('output_dim', 1)
 
         self.word_choice = GumbelSoftmax(configs['word_num'])
 
@@ -31,7 +32,7 @@ class repst(nn.Module):
         self.head_nf = self.d_ff * self.patch_nums
 
         # Word embedding
-        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim)
+        self.patch_embedding = PatchEmbedding(self.d_model, self.patch_len, self.stride, self.dropout, self.patch_nums, self.input_dim, self.output_dim)
 
         # GPT-2 initialization
         self.gpts = GPT2Model.from_pretrained(self.gpt_path, output_attentions=True, output_hidden_states=True)
@@ -41,12 +42,12 @@ class repst(nn.Module):
         self.word_embeddings = self.gpts.get_input_embeddings().weight.to(self.device)
         self.vocab_size = self.word_embeddings.shape[0]
         self.mapping_layer = nn.Linear(self.vocab_size, 1)
-        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_keys, self.d_llm)
+        self.reprogramming_layer = ReprogrammingLayer(self.d_model * self.output_dim, self.n_heads, self.d_keys, self.d_llm)
 
         self.out_mlp = nn.Sequential(
             nn.Linear(self.d_llm, 128),
             nn.ReLU(),
-            nn.Linear(128, self.pred_len)
+            nn.Linear(128, self.pred_len * self.output_dim)
         )
 
         for i, (name, param) in enumerate(self.gpts.named_parameters()):
@@ -62,7 +63,7 @@ class repst(nn.Module):
                 torch.nn.init.zeros_(module.bias)
 
     def forward(self, x):
-        x = x[..., :1]
+        x = x[..., :self.input_dim]
         x_enc = rearrange(x, 'b t n c -> b n c t')
         enc_out, n_vars = self.patch_embedding(x_enc)
         self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)
@@ -72,32 +73,11 @@ class repst(nn.Module):
         enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
         enc_out = self.gpts(inputs_embeds=enc_out).last_hidden_state
 
-        dec_out = self.out_mlp(enc_out)
-        outputs = dec_out.unsqueeze(dim=-1)
-        outputs = outputs.repeat(1, 1, 1, n_vars)
-        outputs = outputs.permute(0,2,1,3)
+        dec_out = self.out_mlp(enc_out)  # [B, N, T*C]
+        B, N, _ = dec_out.shape
+        outputs = dec_out.view(B, N, self.pred_len, self.output_dim)
+        outputs = outputs.permute(0, 2, 1, 3)  # B, T, N, C
 
         return outputs
 
-if __name__ == '__main__':
-    configs = {
-        'device': 'cuda:0',
-        'pred_len': 24,
-        'seq_len': 24,
-        'patch_len': 6,
-        'stride': 7,
-        'dropout': 0.2,
-        'gpt_layers': 9,
-        'd_ff': 128,
-        'gpt_path': './GPT-2',
-        'd_model': 64,
-        'n_heads': 1,
-        'input_dim': 1
-    }
-    model = repst(configs)
-    x = torch.randn(16, 24, 325, 1)
-    y = model(x)
-
-    print(y.shape)
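A hedged sketch of the new multi-channel output path above: out_mlp now emits pred_len * output_dim values per node, which are reshaped instead of being repeated across channels as before. The sizes below mirror the AirQuality config but are only illustrative.

import torch

B, N, pred_len, output_dim = 16, 35, 24, 3
dec_out = torch.randn(B, N, pred_len * output_dim)   # [B, N, T*C]
outputs = dec_out.view(B, N, pred_len, output_dim)    # [B, N, T, C]
outputs = outputs.permute(0, 2, 1, 3)                 # [B, T, N, C]
print(outputs.shape)                                  # torch.Size([16, 24, 35, 3])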
@@ -12,3 +12,4 @@ notebook
 torchcde
 einops
 transformers
+py7zr
run.py
@@ -14,6 +14,8 @@ def main():
     args = parse_args()
     args = init.init_device(args)
     init.init_seed(args["basic"]["seed"])
 
+    # Load model
     model = init.init_model(args)
 
     # Load dataset
@@ -203,7 +203,7 @@ class Trainer:
             self.stats.record_step_time(step_time, mode)
 
             # Accumulate the loss and the predictions
-            total_loss += d_loss.item()
+            total_loss += loss.item()
             y_pred.append(d_output.detach().cpu())
             y_true.append(d_label.detach().cpu())
 
@@ -316,13 +316,9 @@ class Trainer:
 
     def _log_model_params(self):
         """Log the number of trainable model parameters."""
-        try:
-            total_params = sum(
-                p.numel() for p in self.model.parameters() if p.requires_grad
-            )
-            self.logger.info(f"Trainable params: {total_params}")
-        except Exception:
-            pass
+        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+        self.logger.info(f"Trainable params: {total_params}")
 
     def _finalize_training(self, best_model, best_test_model):
         self.model.load_state_dict(best_model)
@@ -353,35 +349,26 @@ class Trainer:
             for data, target in data_loader:
                 label = target[..., : args["output_dim"]]
                 output = model(data)
-                y_pred.append(output)
-                y_true.append(label)
+                y_pred.append(output.detach().cpu())
+                y_true.append(label.detach().cpu())
 
-        # Merge the predictions from all batches
-        if args["real_value"]:
-            y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
-        else:
-            y_pred = torch.cat(y_pred, dim=0)
-        y_true = torch.cat(y_true, dim=0)
+        d_y_pred = scaler.inverse_transform(torch.cat(y_pred, dim=0))
+        d_y_true = scaler.inverse_transform(torch.cat(y_true, dim=0))
 
         # Compute and log the metrics for every horizon step
-        for t in range(y_true.shape[1]):
+        for t in range(d_y_true.shape[1]):
             mae, rmse, mape = all_metrics(
-                y_pred[:, t, ...],
-                y_true[:, t, ...],
+                d_y_pred[:, t, ...],
+                d_y_true[:, t, ...],
                 args["mae_thresh"],
                 args["mape_thresh"],
             )
-            logger.info(
-                f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
-            )
+            logger.info(f"Horizon {t + 1:02d}, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
 
         # Compute and log the average metrics
-        mae, rmse, mape = all_metrics(
-            y_pred, y_true, args["mae_thresh"], args["mape_thresh"]
-        )
-        logger.info(
-            f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}"
-        )
+        mae, rmse, mape = all_metrics(d_y_pred, d_y_true, args["mae_thresh"], args["mape_thresh"])
+        logger.info(f"Average Horizon, MAE: {mae:.4f}, RMSE: {rmse:.4f}, MAPE: {mape:.4f}")
 
     @staticmethod
     def _compute_sampling_threshold(global_step, k):
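A hedged sketch of the per-horizon evaluation pattern above, using detached CPU tensors and a stand-in MAE metric in place of the repository's all_metrics helper (inverse_transform is omitted here).

import torch

def mae_metric(pred, true):
    return (pred - true).abs().mean().item()

y_pred = [torch.randn(16, 24, 35, 3) for _ in range(3)]   # per-batch predictions, already detached
y_true = [torch.randn(16, 24, 35, 3) for _ in range(3)]
d_y_pred = torch.cat(y_pred, dim=0)   # scaler.inverse_transform would be applied here
d_y_true = torch.cat(y_true, dim=0)
for t in range(d_y_true.shape[1]):
    print(f"Horizon {t + 1:02d}, MAE: {mae_metric(d_y_pred[:, t], d_y_true[:, t]):.4f}")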
@@ -1,204 +1,191 @@
-import os
-import requests
-import zipfile
-import shutil
-import kagglehub  # assume kagglehub is an available library
+import os, json, shutil, requests
+from urllib.parse import urlsplit
 from tqdm import tqdm
+import kagglehub
+import py7zr
 
-# Dictionary describing the expected file-integrity information
 
+# ---------- 1. Integrity check ----------
+def detect_data_integrity(data_dir, expected):
+    missing_list = []
+    if not os.path.isdir(data_dir):
+        # If the data directory does not exist, every dataset is missing
+        missing_list.extend(expected.keys())
+        # Mark adj as missing as well
+        missing_list.append("adj")
+        return missing_list
+
+    # Check the adj-related files (distance-matrix files)
+    has_missing_adj = False
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if os.path.isdir(folder_path):
+            existing = set(os.listdir(folder_path))
+            for f in files:
+                if f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                    has_missing_adj = True
+                    break
+    if has_missing_adj:
+        missing_list.append("adj")
+
+    # Check the main dataset files
+    for folder, files in expected.items():
+        folder_path = os.path.join(data_dir, folder)
+        if not os.path.isdir(folder_path):
+            missing_list.append(folder)
+            continue
+
+        existing = set(os.listdir(folder_path))
+        has_missing_file = False
+
+        for f in files:
+            # Skip the distance-matrix files; they were already checked above
+            if not f.endswith(("_dtw_distance.npy", "_spatial_distance.npy")) and f not in existing:
+                has_missing_file = True
+
+        if has_missing_file and folder not in missing_list:
+            missing_list.append(folder)
+
+    # print(f"Missing datasets: {missing_list}")
+    return missing_list
+
+
+# ---------- 2. Download and extract a 7z archive ----------
+def download_and_extract(url, dst_dir, max_retries=3):
+    os.makedirs(dst_dir, exist_ok=True)
+    filename = os.path.basename(urlsplit(url).path) or "download.7z"
+    file_path = os.path.join(dst_dir, filename)
+    for attempt in range(1, max_retries+1):
+        try:
+            # Download
+            with requests.get(url, stream=True, timeout=30) as r:
+                r.raise_for_status()
+                total = int(r.headers.get("content-length", 0))
+                with open(file_path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True, desc=filename) as bar:
+                    for chunk in r.iter_content(8192):
+                        f.write(chunk)
+                        bar.update(len(chunk))
+            # Extract the 7z archive
+            with py7zr.SevenZipFile(file_path, mode='r') as archive:
+                archive.extractall(path=dst_dir)
+            os.remove(file_path)
+            return
+        except Exception as e:
+            if attempt == max_retries: raise RuntimeError("Download or extraction failed")
+            print("Error, retrying...", e)
+
+
+# ---------- 3. Download Kaggle data ----------
+def download_kaggle_data(base_dir, dataset):
+    try:
+        print(f"Downloading kaggle dataset : {dataset}")
+        path = kagglehub.dataset_download(dataset)
+        shutil.copytree(path, os.path.join(base_dir, "data"), dirs_exist_ok=True)
+    except Exception as e:
+        print("Kaggle download failed:", dataset, e)
+
+
+# ---------- 4. Download GitHub data ----------
+def download_github_data(file_path, save_dir):
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    raw_url = f"https://ghfast.top/https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    # raw_url = f"https://raw.githubusercontent.com/prabinrath/Traffic-Flow-Prediction/main/{file_path}"
+    response = requests.head(raw_url, allow_redirects=True)
+    if response.status_code != 200:
+        print(f"Failed to get file size for {raw_url}. Status code:", response.status_code)
+        return
+
+    file_size = int(response.headers.get('Content-Length', 0))
+    response = requests.get(raw_url, stream=True, allow_redirects=True)
+    file_name = os.path.basename(file_path)
+    file_path_to_save = os.path.join(save_dir, file_name)
+    with open(file_path_to_save, 'wb') as f:
+        with tqdm(total=file_size, unit='B', unit_scale=True, desc=f"Downloading {file_name}") as pbar:
+            for chunk in response.iter_content(chunk_size=1024):
+                if chunk:
+                    f.write(chunk)
+                    pbar.update(len(chunk))
+
+
+# ---------- 5. Rearrange the data directory ----------
+def rearrange_dir():
+    data_dir = os.path.join(os.getcwd(), "data")
+    nested = os.path.join(data_dir, "data")
+    if os.path.isdir(nested):
+        for item in os.listdir(nested):
+            src, dst = os.path.join(nested, item), os.path.join(data_dir, item)
+            if os.path.isdir(src):
+                shutil.copytree(src, dst, dirs_exist_ok=True)  # update directories that already exist
+            else:
+                shutil.copy2(src, dst)
+        shutil.rmtree(nested)
+
+    for kw, tgt in [("bay", "PEMS-BAY"), ("metr", "METR-LA")]:
+        dst = os.path.join(data_dir, tgt); os.makedirs(dst, exist_ok=True)
+        for f in os.listdir(data_dir):
+            if kw in f.lower() and f.endswith((".h5", ".pkl")):
+                shutil.move(os.path.join(data_dir, f), os.path.join(dst, f))
+
+    solar = os.path.join(data_dir, "solar-energy")
+    if os.path.isdir(solar):
+        dst = os.path.join(data_dir, "SolarEnergy"); os.makedirs(dst, exist_ok=True)
+        csv = os.path.join(solar, "solar_AL.csv")
+        if os.path.isfile(csv): shutil.move(csv, os.path.join(dst, "SolarEnergy.csv"))
+        shutil.rmtree(solar)
+
+
+# ---------- 6. Main flow ----------
 def check_and_download_data():
-    """
-    Check the integrity of the data folder and download whatever is missing, based on the type of missing files.
-    """
-    current_working_dir = os.getcwd()  # current working directory
-    data_dir = os.path.join(
-        current_working_dir, "data"
-    )  # assume the data folder lives in the current working directory
-
-    expected_structure = {
-        "PEMS03": [
-            "PEMS03.csv",
-            "PEMS03.npz",
-            "PEMS03.txt",
-            "PEMS03_dtw_distance.npy",
-            "PEMS03_spatial_distance.npy",
-        ],
-        "PEMS04": [
-            "PEMS04.csv",
-            "PEMS04.npz",
-            "PEMS04_dtw_distance.npy",
-            "PEMS04_spatial_distance.npy",
-        ],
-        "PEMS07": [
-            "PEMS07.csv",
-            "PEMS07.npz",
-            "PEMS07_dtw_distance.npy",
-            "PEMS07_spatial_distance.npy",
-        ],
-        "PEMS08": [
-            "PEMS08.csv",
-            "PEMS08.npz",
-            "PEMS08_dtw_distance.npy",
-            "PEMS08_spatial_distance.npy",
-        ],
-        "PEMS-BAY": [
-            "adj_mx_bay.pkl",
-            "pems-bay-meta.h5",
-            "pems-bay.h5"
-        ]
-    }
-
-    current_dir = os.getcwd()  # current working directory
-    missing_adj = False
-    missing_main_files = False
-
-    # Check whether the data folder exists
-    if not os.path.exists(data_dir) or not os.path.isdir(data_dir):
-        # print(f"Directory {data_dir} does not exist.")
-        print("Downloading all required data files...")
-        missing_adj = True
-        missing_main_files = True
-    else:
-        # Walk the expected file structure
-        for subfolder, expected_files in expected_structure.items():
-            subfolder_path = os.path.join(data_dir, subfolder)
-
-            # Check whether the subfolder exists
-            if not os.path.exists(subfolder_path) or not os.path.isdir(subfolder_path):
-                # print(f"Subfolder {subfolder} does not exist.")
-                missing_main_files = True
-                continue
-
-            # List the files actually present in the subfolder
-            actual_files = os.listdir(subfolder_path)
-
-            # Check for missing files
-            for expected_file in expected_files:
-                if expected_file not in actual_files:
-                    # print(f"File {expected_file} is missing from subfolder {subfolder}.")
-                    if (
-                        "_dtw_distance.npy" in expected_file
-                        or "_spatial_distance.npy" in expected_file
-                    ):
-                        missing_adj = True
-                    else:
-                        missing_main_files = True
-
-    # Dispatch the download logic based on the type of missing files
-    if missing_adj:
-        download_adj_data(current_dir)
-    if missing_main_files:
-        download_kaggle_data(current_dir, 'elmahy/pems-dataset')
-        download_kaggle_data(current_dir, 'scchuy/pemsbay')
+    # Load the structure file and detect missing datasets
+    cwd = os.getcwd()
+    data_dir = os.path.join(cwd, "data")
+    with open("utils/dataset.json", "r", encoding="utf-8") as f:
+        file_tree = json.load(f)
+    missing_list = detect_data_integrity(data_dir, file_tree)
+    # print(f"Missing datasets: {missing_list}")
+
+    # Check for and download the adj data
+    if "adj" in missing_list:
+        download_and_extract("http://code.zhang-heng.com/static/adj.7z", data_dir)
+        # Remove adj from the missing list after downloading
+        missing_list.remove("adj")
+
+    # Check BeijingAirQuality and AirQuality
+    if "BeijingAirQuality" in missing_list or "AirQuality" in missing_list:
+        download_and_extract("http://code.zhang-heng.com/static/BeijingAirQuality.7z", data_dir)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # Check for and download the TaxiBJ data
+    if "TaxiBJ" in missing_list:
+        taxi_bj_floder = os.path.join(data_dir, "BeijingTaxi")
+        taxibj_files = ['TaxiBJ2013.npy', 'TaxiBJ2014.npy', 'TaxiBJ2015.npy', 'TaxiBJ2016_1.npy', 'TaxiBJ2016_2.npy']
+        for file in taxibj_files:
+            file_path = f"Datasets/TaxiBJ/{file}"
+            download_github_data(file_path, taxi_bj_floder)
+        # Refresh the missing list after downloading
+        missing_list = detect_data_integrity(data_dir, file_tree)
+
+    # Check for and download the pems, bay, metr-la and solar-energy data
+    kaggle_map = {
+        "PEMS03": "elmahy/pems-dataset",
+        "PEMS04": "elmahy/pems-dataset",
+        "PEMS07": "elmahy/pems-dataset",
+        "PEMS08": "elmahy/pems-dataset",
+        "PEMS-BAY": "scchuy/pemsbay",
+        "METR-LA": "annnnguyen/metr-la-dataset",
+        "SolarEnergy": "wangshaoqi/solar-energy"
+    }
+
+    # Deduplicate the Kaggle sources first to avoid downloading the same dataset twice
+    downloaded_kaggle_datasets = set()
+
+    for dataset, kaggle_ds in kaggle_map.items():
+        if dataset in missing_list and kaggle_ds not in downloaded_kaggle_datasets:
+            download_kaggle_data(cwd, kaggle_ds)
+            # Record the source as already downloaded
+            downloaded_kaggle_datasets.add(kaggle_ds)
+            # Refresh the missing list after each download
+            missing_list = detect_data_integrity(data_dir, file_tree)
 
     rearrange_dir()
 
     return True
 
-
-def download_adj_data(current_dir, max_retries=3):
-    """
-    Download and extract adj.zip, showing a progress bar.
-    Retry at most max_retries times if the download fails.
-    """
-    url = "http://code.zhang-heng.com/static/adj.zip"
-    retries = 0
-
-    while retries <= max_retries:
-        try:
-            print(f"Downloading the adjacency matrix files from {url} ...")
-            response = requests.get(url, stream=True)
-
-            if response.status_code == 200:
-                total_size = int(response.headers.get("content-length", 0))
-                block_size = 1024  # 1KB
-                t = tqdm(total=total_size, unit="B", unit_scale=True, desc="Download progress")
-
-                zip_file_path = os.path.join(current_dir, "adj.zip")
-                with open(zip_file_path, "wb") as f:
-                    for data in response.iter_content(block_size):
-                        f.write(data)
-                        t.update(len(data))
-                t.close()
-
-                # print("Download finished, file saved to:", zip_file_path)
-
-                if os.path.exists(zip_file_path):
-                    with zipfile.ZipFile(zip_file_path, "r") as zip_ref:
-                        zip_ref.extractall(current_dir)
-                    # print("Dataset extracted to:", current_dir)
-                    os.remove(zip_file_path)  # delete the zip file
-                else:
-                    print("Downloaded zip file not found; skipping extraction.")
-                break  # download succeeded, leave the loop
-            else:
-                print(f"Download failed with status code {response.status_code}. Please check that the link is valid.")
-        except Exception as e:
-            print(f"Error while downloading or extracting the dataset: {e}")
-            print("If the link is invalid, check the URL or retry later.")
-
-        retries += 1
-        if retries > max_retries:
-            raise Exception(
-                f"Download failed after the maximum number of retries ({max_retries}). Check the link or the network connection."
-            )
-
-
-def download_kaggle_data(current_dir, kaggle_path):
-    """
-    Download a KaggleHub dataset and move it straight into the data folder of the current working directory.
-    Conflicting files are overwritten if the target folder already exists.
-    """
-    try:
-        print(f"Downloading the {kaggle_path} dataset...")
-        path = kagglehub.dataset_download(kaggle_path)
-        # print("Path to KaggleHub dataset files:", path)
-
-        if os.path.exists(path):
-            destination_path = os.path.join(current_dir, "data")
-            # Use shutil.copytree to place the folder contents directly under data, overwriting conflicts
-            shutil.copytree(path, destination_path, dirs_exist_ok=True)
-    except Exception as e:
-        print(f"Error while downloading or processing the KaggleHub dataset: {e}")
-
-
-def rearrange_dir():
-    """
-    Merge the files in data/data into the parent directory and delete the data/data directory.
-    """
-    data_dir = os.path.join(os.getcwd(), "data")
-    nested_data_dir = os.path.join(data_dir, "data")
-
-    if os.path.exists(nested_data_dir) and os.path.isdir(nested_data_dir):
-        for item in os.listdir(nested_data_dir):
-            source_path = os.path.join(nested_data_dir, item)
-            destination_path = os.path.join(data_dir, item)
-
-            if os.path.isdir(source_path):
-                shutil.copytree(source_path, destination_path, dirs_exist_ok=True)
-            else:
-                shutil.copy2(source_path, destination_path)
-
-        shutil.rmtree(nested_data_dir)
-        # print(f"Merged {nested_data_dir} into {data_dir} and removed the nested directory.")
-
-    # Move files whose names contain "bay" into the PEMS-BAY folder
-    pems_bay_dir = os.path.join(data_dir, "PEMS-BAY")
-    os.makedirs(pems_bay_dir, exist_ok=True)
-
-    for item in os.listdir(data_dir):
-        if "bay" in item.lower() and (item.endswith(".pkl") or item.endswith(".h5")):
-            source_path = os.path.join(data_dir, item)
-            destination_path = os.path.join(pems_bay_dir, item)
-            shutil.move(source_path, destination_path)
-
-    # print(f"Moved the files containing 'bay' into {pems_bay_dir}.")
-
-
-# Main program
-if __name__ == "__main__":
+if __name__=="__main__":
     check_and_download_data()
-    # rearrange_dir()
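A hedged usage sketch of the new download helpers above; it assumes the script is run from the repository root so that utils/dataset.json and ./data resolve correctly.

import json

with open("utils/dataset.json", "r", encoding="utf-8") as f:
    file_tree = json.load(f)
missing = detect_data_integrity("./data", file_tree)
print(missing)  # e.g. ['adj', 'PEMS-BAY', 'BeijingTaxi'] on a fresh checkout
if "adj" in missing:
    download_and_extract("http://code.zhang-heng.com/static/adj.7z", "./data")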
@@ -0,0 +1,41 @@
+{
+    "PEMS03": [
+        "PEMS03.csv",
+        "PEMS03.npz",
+        "PEMS03.txt",
+        "PEMS03_dtw_distance.npy",
+        "PEMS03_spatial_distance.npy"
+    ],
+    "PEMS04": [
+        "PEMS04.csv",
+        "PEMS04.npz",
+        "PEMS04_dtw_distance.npy",
+        "PEMS04_spatial_distance.npy"
+    ],
+    "PEMS07": [
+        "PEMS07.csv",
+        "PEMS07.npz",
+        "PEMS07_dtw_distance.npy",
+        "PEMS07_spatial_distance.npy"
+    ],
+    "PEMS08": [
+        "PEMS08.csv",
+        "PEMS08.npz",
+        "PEMS08_dtw_distance.npy",
+        "PEMS08_spatial_distance.npy"
+    ],
+    "PEMS-BAY": [
+        "adj_mx_bay.pkl",
+        "pems-bay-meta.h5",
+        "pems-bay.h5"
+    ],
+    "METR-LA": [
+        "METR-LA.h5"
+    ],
+    "SolarEnergy": [
+        "SolarEnergy.csv"
+    ],
+    "BeijingAirQuality": ["data.dat", "desc.json"],
+    "AirQuality": ["data.dat"],
+    "BeijingTaxi": ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
+}
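A hedged sketch of how the manifest above can be compared against a local folder, which is what detect_data_integrity does dataset by dataset; the path "utils/dataset.json" is assumed from check_and_download_data.

import json, os

with open("utils/dataset.json", "r", encoding="utf-8") as f:
    manifest = json.load(f)
folder = "PEMS-BAY"
folder_path = os.path.join("data", folder)
present = set(os.listdir(folder_path)) if os.path.isdir(folder_path) else set()
missing_files = [name for name in manifest[folder] if name not in present]
print(missing_files)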