import os import numpy as np import h5py def load_st_dataset(config): dataset = config["basic"]["dataset"] # sample = config["data"]["sample"] # output B, N, D match dataset: case "BeijingAirQuality": data_path = os.path.join("./data/BeijingAirQuality/data.dat") data = np.memmap(data_path, dtype=np.float32, mode='r') L, N, C = 36000, 7, 3 data = data.reshape(L, N, C) case "AirQuality": data_path = os.path.join("./data/AirQuality/data.dat") data = np.memmap(data_path, dtype=np.float32, mode='r') L, N, C = 8701,35,6 data = data.reshape(L, N, C) case "PEMS-BAY": data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5") with h5py.File(data_path, 'r') as f: data = f['speed']['block0_values'][:] case "METR-LA": data_path = os.path.join("./data/METR-LA/METR-LA.h5") with h5py.File(data_path, 'r') as f: data = f['df']['block0_values'][:] case "SolarEnergy": data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv") data = np.loadtxt(data_path, delimiter=",") case "PEMSD3": data_path = os.path.join("./data/PEMS03/PEMS03.npz") data = np.load(data_path)["data"][:, :, 0] case "PEMSD4": data_path = os.path.join("./data/PEMS04/PEMS04.npz") data = np.load(data_path)["data"][:, :, 0] case "PEMSD7": data_path = os.path.join("./data/PEMS07/PEMS07.npz") data = np.load(data_path)["data"][:, :, 0] case "PEMSD8": data_path = os.path.join("./data/PEMS08/PEMS08.npz") data = np.load(data_path)["data"][:, :, 0] case "PEMSD7(L)": data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz") data = np.load(data_path)["data"][:, :, 0] case "PEMSD7(M)": data_path = os.path.join("./data/PEMS07(M)/V_228.csv") data = np.genfromtxt(data_path, delimiter=",") case "BJ": data_path = os.path.join("./data/BJ/BJ500.csv") data = np.genfromtxt(data_path, delimiter=",", skip_header=1) case "Hainan": data_path = os.path.join("./data/Hainan/Hainan.npz") data = np.load(data_path)["data"][:, :, 0] case "SD": data_path = os.path.join("./data/SD/data.npz") data = np.load(data_path)["data"][:, :, 0].astype(np.float32) case "BJTaxi-InFlow": data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32) case "BJTaxi-OutFlow": data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32) case _: raise ValueError(f"Unsupported dataset: {dataset}") # Ensure data shape compatibility if len(data.shape) == 2: data = np.expand_dims(data, axis=-1) print("加载 %s 数据集中... " % dataset) # return data[::sample] return data def read_BeijingTaxi(): files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy", "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"] all_data = [] for file in files: data_path = os.path.join(f"./data/BeijingTaxi/{file}") data = np.load(data_path) all_data.append(data) all_data = np.concatenate(all_data, axis=0) time_num = all_data.shape[0] all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32*32, 2) return all_data