TrafficWheel/dataloader/data_selector.py

import numpy as np
import h5py
def load_st_dataset(config):
    """Load the raw spatio-temporal array for the configured dataset.

    Returns an array of shape (T, N, C): time steps, nodes, channels.
    """
    dataset = config["basic"]["dataset"]
    # Dataset name -> lazy loader; each loader returns (T, N) or (T, N, C).
    loaders = {
        "BeijingAirQuality": lambda: _memmap("./data/BeijingAirQuality/data.dat", 36000, 7, 3),
        "AirQuality": lambda: _memmap("./data/AirQuality/data.dat", 8701, 35, 6),
        "PEMS-BAY": lambda: _h5("./data/PEMS-BAY/pems-bay.h5", ("speed", "block0_values")),
        "METR-LA": lambda: _h5("./data/METR-LA/METR-LA.h5", ("df", "block0_values")),
        "SolarEnergy": lambda: np.loadtxt("./data/SolarEnergy/SolarEnergy.csv", delimiter=","),
        "PEMSD3": lambda: _npz("./data/PEMS03/PEMS03.npz"),
        "PEMSD4": lambda: _npz("./data/PEMS04/PEMS04.npz"),
        "PEMSD7": lambda: _npz("./data/PEMS07/PEMS07.npz"),
        "PEMSD8": lambda: _npz("./data/PEMS08/PEMS08.npz"),
        "PEMSD7(L)": lambda: _npz("./data/PEMS07(L)/PEMS07L.npz"),
        "PEMSD7(M)": lambda: np.genfromtxt("./data/PEMS07(M)/V_228.csv", delimiter=","),
        "BJ": lambda: np.genfromtxt("./data/BJ/BJ500.csv", delimiter=",", skip_header=1),
        "Hainan": lambda: _npz("./data/Hainan/Hainan.npz"),
        "SD": lambda: _npz("./data/SD/data.npz", cast=True),
        "BJTaxi-InFlow": lambda: read_BeijingTaxi()[:, :, 0:1].astype(np.float32),
        "BJTaxi-OutFlow": lambda: read_BeijingTaxi()[:, :, 1:2].astype(np.float32),
        "NYCBike-InFlow": lambda: _nyc_bike(0),
        "NYCBike-OutFlow": lambda: _nyc_bike(1),
    }
    if dataset not in loaders:
        raise ValueError(f"Unsupported dataset: {dataset}")
    print(f"Loading {dataset} dataset...")
    data = loaders[dataset]()
    # Ensure a trailing channel dimension so every dataset comes out as (T, N, C).
    if data.ndim == 2:
        data = data[..., None]
    return data

# ---------------- helpers ----------------
def _memmap(path, L, N, C):
    # Read a flat float32 binary file and reshape to (time, nodes, channels).
    data = np.memmap(path, dtype=np.float32, mode="r")
    return data.reshape(L, N, C)

def _h5(path, keys):
    # Read a pandas-style HDF5 file via h5py; keys = (group, dataset).
    with h5py.File(path, "r") as f:
        return f[keys[0]][keys[1]][:]

def _npz(path, cast=False):
    # Keep only the first channel of the stored (T, N, C) array.
    data = np.load(path)["data"][:, :, 0]
    return data.astype(np.float32) if cast else data

def _nyc_bike(channel):
    # Stored as (T, 2, 16, 8); flatten the 16x8 grid into 128 nodes and keep
    # the requested flow channel (0: in-flow, 1: out-flow).
    with h5py.File("./data/NYCBike/NYC16x8.h5", "r") as f:
        data = f["data"][:].astype(np.float32)
    data = data.transpose(0, 2, 3, 1).reshape(-1, 16 * 8, 2)
    return data[:, :, channel:channel + 1]

def read_BeijingTaxi():
    # Concatenate the yearly TaxiBJ dumps along time, then flatten the 32x32
    # grid into 1024 nodes with 2 channels (in-flow, out-flow).
    files = [
        "TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
        "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy",
    ]
    data = np.concatenate(
        [np.load(f"./data/BeijingTaxi/{f}") for f in files], axis=0
    )
    T = data.shape[0]
    return data.transpose(0, 2, 3, 1).reshape(T, 32 * 32, 2)
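
# ---------------- usage sketch ----------------
# A minimal sketch of how this loader might be invoked. The config layout
# ({"basic": {"dataset": ...}}) is taken from load_st_dataset above; the
# dataset name here is only an example and assumes the corresponding files
# exist under ./data/.
if __name__ == "__main__":
    config = {"basic": {"dataset": "PEMSD4"}}
    data = load_st_dataset(config)
    print(data.shape, data.dtype)  # expected shape: (T, N, C)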