import os

import numpy as np
import h5py


def _load_npz_channel0(path):
    """Return ``data[:, :, 0]`` from the ``data`` array of an ``.npz`` archive."""
    # ``with`` closes the archive's underlying file handle; indexing the
    # member has already read it fully into memory, so the slice stays valid.
    with np.load(path) as archive:
        return archive["data"][:, :, 0]


def _load_h5_block0(path, group):
    """Return ``f[group]["block0_values"]`` of an HDF5 file as an in-memory array."""
    with h5py.File(path, "r") as f:
        return f[group]["block0_values"][:]


def _load_csv(path, skip_header=0):
    """Parse a comma-delimited text file into a numeric array."""
    return np.genfromtxt(path, delimiter=",", skip_header=skip_header)


# Dataset name -> (path components relative to ``data_root``, loader(path)).
# NOTE: the original code had a second, unreachable ``"METR-LA"`` case that
# would have read ``METR-LA/METR.h5`` under key ``"data"``; since the first
# ``"METR-LA"`` case always matched, only the entry below ever ran.
_DATASETS = {
    "PEMS-BAY": (("PEMS-BAY", "pems-bay.h5"), lambda p: _load_h5_block0(p, "speed")),
    "METR-LA": (("METR-LA", "METR-LA.h5"), lambda p: _load_h5_block0(p, "df")),
    "PEMSD3": (("PEMS03", "PEMS03.npz"), _load_npz_channel0),
    "PEMSD4": (("PEMS04", "PEMS04.npz"), _load_npz_channel0),
    "PEMSD7": (("PEMS07", "PEMS07.npz"), _load_npz_channel0),
    "PEMSD8": (("PEMS08", "PEMS08.npz"), _load_npz_channel0),
    "PEMSD7(L)": (("PEMS07(L)", "PEMS07L.npz"), _load_npz_channel0),
    "PEMSD7(M)": (("PEMS07(M)", "V_228.csv"), _load_csv),
    "BJ": (("BJ", "BJ500.csv"), lambda p: _load_csv(p, skip_header=1)),
    "Hainan": (("Hainan", "Hainan.npz"), _load_npz_channel0),
    "SD": (("SD", "data.npz"), lambda p: _load_npz_channel0(p).astype(np.float32)),
}


def load_st_dataset(config, data_root="./data"):
    """Load the spatio-temporal dataset named by ``config["basic"]["dataset"]``.

    Args:
        config: Nested mapping; only ``config["basic"]["dataset"]`` is read.
            Supported names are the keys of ``_DATASETS``.
        data_root: Directory holding the per-dataset subfolders. Defaults to
            ``"./data"`` (relative to the current working directory), which
            matches the previously hard-coded location.

    Returns:
        A NumPy array. If the loaded array is 2-D, a trailing axis of size 1
        is appended. Presumably the result is (time, nodes, 1); TODO confirm
        against callers.

    Raises:
        ValueError: If the dataset name is not supported.
        KeyError: If ``config`` lacks ``["basic"]["dataset"]``.
    """
    dataset = config["basic"]["dataset"]

    try:
        rel_parts, loader = _DATASETS[dataset]
    except KeyError:
        raise ValueError(f"Unsupported dataset: {dataset}") from None

    data = loader(os.path.join(data_root, *rel_parts))

    # Give 2-D data an explicit feature axis so downstream code sees 3-D input.
    if len(data.shape) == 2:
        data = np.expand_dims(data, axis=-1)

    # Message reads "Loading <dataset> dataset..."; note it prints after loading.
    print("加载 %s 数据集中... " % dataset)
    return data