import os

import numpy as np

try:
    import h5py
except ImportError as _err:
    # h5py is only needed by the HDF5-backed datasets; keep the module
    # importable so the .npz/.csv/.npy/.dat datasets still work without it.
    h5py = None
    _H5PY_IMPORT_ERROR = _err
else:
    _H5PY_IMPORT_ERROR = None


def load_st_dataset(config: dict) -> np.ndarray:
    """Load the dataset named by ``config["basic"]["dataset"]``.

    All file paths are relative to the current working directory
    (``./data/...``).

    Args:
        config: Mapping that provides ``config["basic"]["dataset"]``, the
            name of one of the supported datasets listed in the dispatch
            table below.

    Returns:
        The loaded array. If the underlying loader produced a 2-D array, a
        trailing axis of length 1 is appended; arrays of any other rank are
        returned unchanged.

    Raises:
        ValueError: If the dataset name is not supported.
        ImportError: If an HDF5-backed dataset is requested but h5py is not
            installed.
    """
    dataset = config["basic"]["dataset"]

    # Each entry is a zero-argument callable so that only the requested
    # dataset is actually read from disk.
    loaders = {
        "BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3),
        "AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6),
        "PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")),
        "METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")),
        "SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),
        "PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
        "PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
        "PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
        "PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),
        "PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
        "PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),
        "BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
        "Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
        "SD": lambda: _npz("./data/SD/data.npz", cast=True),
        "BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
        "BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),
        "NYCBike-InFlow": lambda: _nyc_bike(0),
        "NYCBike-OutFlow": lambda: _nyc_bike(1),
    }

    loader = loaders.get(dataset)
    if loader is None:
        raise ValueError(f"Unsupported dataset: {dataset}")

    data = loader()

    # Give 2-D results an explicit trailing axis so callers always receive
    # at least three dimensions from the 2-D loaders.
    if data.ndim == 2:
        data = data[..., None]

    # The message reads "Loading <dataset> dataset...".
    print(f"加载 {dataset} 数据集中... \n")
    return data


# ---------------- helpers ----------------

def _memmap(path: str, L: int, N: int, C: int) -> np.ndarray:
    """Map a raw float32 file read-only and view it as shape ``(L, N, C)``.

    The reshape raises ``ValueError`` if the file does not hold exactly
    ``L * N * C`` float32 values.
    """
    data = np.memmap(path, dtype=np.float32, mode="r")
    return data.reshape(L, N, C)


def _open_h5(path: str):
    """Open an HDF5 file read-only, failing clearly if h5py is unavailable."""
    if h5py is None:
        raise ImportError(
            f"h5py is required to read HDF5 file: {path}"
        ) from _H5PY_IMPORT_ERROR
    return h5py.File(path, "r")


def _h5(path: str, keys: tuple) -> np.ndarray:
    """Read ``file[keys[0]][keys[1]]`` fully into memory from an HDF5 file."""
    with _open_h5(path) as f:
        return f[keys[0]][keys[1]][:]


def _npz(path: str, cast: bool = False) -> np.ndarray:
    """Return ``data[:, :, 0]`` from the ``"data"`` entry of an ``.npz`` file.

    Args:
        path: Path to the ``.npz`` archive.
        cast: When true, convert the result to float32.
    """
    # Use the archive as a context manager so its file handle is closed;
    # the array is fully read into memory before the archive closes.
    with np.load(path) as archive:
        data = archive["data"][:, :, 0]
    return data.astype(np.float32) if cast else data


def _nyc_bike(channel: int) -> np.ndarray:
    """Load one channel of the NYCBike grid data as ``(-1, 128, 1)``.

    The ``"data"`` entry is cast to float32, its axis 1 is moved to the end,
    and the two middle axes are flattened into ``16 * 8`` positions.
    Presumably the stored layout is (time, 2, 16, 8) -- TODO confirm.
    """
    with _open_h5("./data/NYCBike/NYC16x8.h5") as f:
        data = f["data"][:].astype(np.float32)
    data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2)
    return data[:, :, channel:channel + 1]


def read_BeijingTaxi() -> np.ndarray:
    """Concatenate the TaxiBJ ``.npy`` files and flatten the grid.

    The files are joined along axis 0, axis 1 is moved to the end, and the
    result is reshaped to ``(T, 32 * 32, 2)`` where ``T`` is the total
    length along axis 0. Presumably each file is (time, 2, 32, 32) --
    TODO confirm.
    """
    files = [
        "TaxiBJ2013.npy",
        "TaxiBJ2014.npy",
        "TaxiBJ2015.npy",
        "TaxiBJ2016_1.npy",
        "TaxiBJ2016_2.npy",
    ]
    data = np.concatenate(
        [np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0
    )
    T = data.shape[0]
    return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2)