82 lines
2.6 KiB
Python
82 lines
2.6 KiB
Python
import os
|
|
import numpy as np
|
|
import h5py
|
|
|
|
|
|
def load_st_dataset(config):
    """Load the dataset selected by ``config["basic"]["dataset"]``.

    Args:
        config: Nested mapping; only ``config["basic"]["dataset"]`` is read.
            Its value must be one of the keys of the loader table below.

    Returns:
        The array produced by the selected loader. A two-dimensional result
        gets a trailing axis of length 1 appended; results of any other rank
        are returned exactly as loaded.

    Raises:
        ValueError: If the dataset name has no registered loader.
    """
    dataset = config["basic"]["dataset"]

    # Every entry is a zero-argument callable, so building the table touches
    # no files; only the selected dataset is actually read from disk.
    loaders = {
        "BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3),
        "AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6),

        "PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")),
        "METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")),

        "SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),

        "PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
        "PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
        "PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
        "PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),

        "PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
        "PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),

        "BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
        "Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
        "SD": lambda: _npz("./data/SD/data.npz", cast=True),

        "BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
        "BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),

        "NYCBike-InFlow": lambda: _nyc_bike(0),
        "NYCBike-OutFlow": lambda: _nyc_bike(1),
    }

    loader = loaders.get(dataset)
    if loader is None:
        raise ValueError(f"Unsupported dataset: {dataset}")

    # Announce the load before starting it: the message says "loading ...",
    # so it should be visible while a slow read is in progress and also when
    # the read fails, not only after everything has already finished.
    print(f"加载 {dataset} 数据集中... ")

    data = loader()

    # Matrices get a trailing singleton axis; other ranks pass through as-is.
    if data.ndim == 2:
        data = data[..., None]

    return data
|
|
|
|
|
|
# ---------------- helpers ----------------
|
|
def _memmap(path, L, N, C):
    """Memory-map a raw float32 file read-only and view it as ``(L, N, C)``.

    Args:
        path: Location of the raw binary file.
        L, N, C: Target sizes of the three axes, in that order.

    Returns:
        A read-only ``np.memmap`` of shape ``(L, N, C)``. The reshape raises
        if the file does not hold exactly ``L * N * C`` float32 values.
    """
    target_shape = (L, N, C)
    # Mapped flat first; the whole file is interpreted as float32 values.
    flat = np.memmap(path, dtype=np.float32, mode="r")
    return flat.reshape(target_shape)
|
|
|
|
|
|
def _h5(path, keys):
    """Read one dataset nested a single level deep in an HDF5 file.

    Args:
        path: Location of the HDF5 file, opened read-only.
        keys: Indexable pair; ``keys[0]`` is looked up on the file and
            ``keys[1]`` on the object that lookup returns.

    Returns:
        The result of slicing the selected dataset with ``[:]``.
    """
    with h5py.File(path, "r") as h5_file:
        group = h5_file[keys[0]]
        dataset = group[keys[1]]
        # Slice inside the ``with`` so the read happens before the file closes.
        return dataset[:]
|
|
|
|
|
|
def _npz(path, cast=False):
    """Return index 0 of the last axis of the ``"data"`` array in an ``.npz``.

    Args:
        path: Location of the ``.npz`` archive; it must contain an entry
            named ``"data"`` with at least three axes.
        cast: When true, the result is converted to ``np.float32``;
            otherwise the stored dtype is kept.

    Returns:
        The two-dimensional slice ``data[:, :, 0]``.
    """
    # np.load on an .npz yields an NpzFile that keeps the archive's file
    # handle open; the context manager releases it deterministically. The
    # array is fully read into memory by the lookup, so it stays valid after
    # the archive is closed.
    with np.load(path) as archive:
        data = archive["data"][:, :, 0]
    return data.astype(np.float32) if cast else data
|
|
|
|
|
|
def _nyc_bike(channel):
    """Load one channel of the NYCBike grid data as ``(-1, 128, 1)`` float32.

    The ``"data"`` entry of ``./data/NYCBike/NYC16x8.h5`` is read, cast to
    float32, has its axis 1 moved to the end, and is reshaped so the two
    middle axes collapse into ``16 * 8`` positions with 2 trailing channels.

    Args:
        channel: Index into the trailing axis of size 2; ``load_st_dataset``
            passes 0 for "InFlow" and 1 for "OutFlow".

    Returns:
        The selected channel, kept as a trailing axis of length 1.
    """
    with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as h5_file:
        raw = h5_file["data"][:]

    as_float = raw.astype(np.float32)
    # Move axis 1 last so the reshape leaves it as the trailing size-2 axis.
    channels_last = as_float.transpose(0, 2, 3, 1)
    flattened = channels_last.reshape(-1, 16 * 8, 2)
    return flattened[:, :, channel:channel + 1]
|
|
|
|
|
|
def read_BeijingTaxi():
    """Load and merge the five BeijingTaxi ``.npy`` files into one array.

    The files are read from ``./data/BeijingTaxi/`` in a fixed order and
    concatenated along axis 0. Axis 1 is then moved to the end and the two
    middle axes are flattened into ``32 * 32`` positions.

    Returns:
        An array of shape ``(T, 1024, 2)``, where ``T`` is the combined
        length of axis 0 across all files.
    """
    file_names = (
        "TaxiBJ2013.npy",
        "TaxiBJ2014.npy",
        "TaxiBJ2015.npy",
        "TaxiBJ2016_1.npy",
        "TaxiBJ2016_2.npy",
    )

    parts = []
    for name in file_names:
        parts.append(np.load(f"./data/BeijingTaxi/{name}"))

    merged = np.concatenate(parts, axis=0)
    n_steps = merged.shape[0]

    # Move axis 1 last so the reshape leaves it as the trailing size-2 axis.
    channels_last = merged.transpose(0, 2, 3, 1)
    return channels_last.reshape(n_steps, 32 * 32, 2)
|