71 lines
2.4 KiB
Python
71 lines
2.4 KiB
Python
import os
|
|
import numpy as np
|
|
import h5py
|
|
|
|
def load_st_dataset(config):
|
|
dataset = config["basic"]["dataset"]
|
|
# sample = config["data"]["sample"]
|
|
# output B, N, D
|
|
match dataset:
|
|
case "PEMS-BAY":
|
|
data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
|
|
with h5py.File(data_path, 'r') as f:
|
|
data = f['speed']['block0_values'][:]
|
|
case "PEMSD3":
|
|
data_path = os.path.join("./data/PEMS03/PEMS03.npz")
|
|
data = np.load(data_path)["data"][
|
|
:, :, 0
|
|
]
|
|
case "PEMSD4":
|
|
data_path = os.path.join("./data/PEMS04/PEMS04.npz")
|
|
data = np.load(data_path)["data"][
|
|
:, :, 0
|
|
]
|
|
case "PEMSD7":
|
|
data_path = os.path.join("./data/PEMS07/PEMS07.npz")
|
|
data = np.load(data_path)["data"][
|
|
:, :, 0
|
|
]
|
|
case "PEMSD8":
|
|
data_path = os.path.join("./data/PEMS08/PEMS08.npz")
|
|
data = np.load(data_path)["data"][
|
|
:, :, 0
|
|
]
|
|
case "PEMSD7(L)":
|
|
data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
|
|
data = np.load(data_path)["data"][
|
|
:, :, 0
|
|
]
|
|
case "PEMSD7(M)":
|
|
data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
|
|
data = np.genfromtxt(
|
|
data_path, delimiter=","
|
|
)
|
|
case "METR-LA":
|
|
data_path = os.path.join("./data/METR-LA/METR.h5")
|
|
with h5py.File(
|
|
data_path, "r"
|
|
) as f:
|
|
data = np.array(f["data"])
|
|
case "BJ":
|
|
data_path = os.path.join("./data/BJ/BJ500.csv")
|
|
data = np.genfromtxt(
|
|
data_path, delimiter=",", skip_header=1
|
|
)
|
|
case "Hainan":
|
|
data_path = os.path.join("./data/Hainan/Hainan.npz")
|
|
data = np.load(data_path)["data"][:, :, 0]
|
|
case "SD":
|
|
data_path = os.path.join("./data/SD/data.npz")
|
|
data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
|
|
case _:
|
|
raise ValueError(f"Unsupported dataset: {dataset}")
|
|
|
|
# Ensure data shape compatibility
|
|
if len(data.shape) == 2:
|
|
data = np.expand_dims(data, axis=-1)
|
|
|
|
print("加载 %s 数据集中... " % dataset)
|
|
# return data[::sample]
|
|
return data
|