# TrafficWheel/dataloader/data_selector.py

import os
import numpy as np
import h5py


def load_st_dataset(config):
    """Load the raw spatio-temporal array for the dataset named in config["basic"]["dataset"]."""
    dataset = config["basic"]["dataset"]
    # sample = config["data"]["sample"]
    # output shape: (L, N, C) = (time steps, nodes, channels)
    match dataset:
        case "BeijingAirQuality":
            data_path = os.path.join("./data/BeijingAirQuality/data.dat")
            data = np.memmap(data_path, dtype=np.float32, mode='r')
            L, N, C = 36000, 7, 3
            data = data.reshape(L, N, C)
        case "AirQuality":
            data_path = os.path.join("./data/AirQuality/data.dat")
            data = np.memmap(data_path, dtype=np.float32, mode='r')
            L, N, C = 8701, 35, 6
            data = data.reshape(L, N, C)
        case "PEMS-BAY":
            data_path = os.path.join("./data/PEMS-BAY/pems-bay.h5")
            # pandas-written HDF5: raw values live under block0_values
            with h5py.File(data_path, 'r') as f:
                data = f['speed']['block0_values'][:]
        case "METR-LA":
            data_path = os.path.join("./data/METR-LA/METR-LA.h5")
            with h5py.File(data_path, 'r') as f:
                data = f['df']['block0_values'][:]
        case "SolarEnergy":
            data_path = os.path.join("./data/SolarEnergy/SolarEnergy.csv")
            data = np.loadtxt(data_path, delimiter=",")
        case "PEMSD3":
            data_path = os.path.join("./data/PEMS03/PEMS03.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "PEMSD4":
            data_path = os.path.join("./data/PEMS04/PEMS04.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "PEMSD7":
            data_path = os.path.join("./data/PEMS07/PEMS07.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "PEMSD8":
            data_path = os.path.join("./data/PEMS08/PEMS08.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "PEMSD7(L)":
            data_path = os.path.join("./data/PEMS07(L)/PEMS07L.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "PEMSD7(M)":
            data_path = os.path.join("./data/PEMS07(M)/V_228.csv")
            data = np.genfromtxt(data_path, delimiter=",")
        case "BJ":
            data_path = os.path.join("./data/BJ/BJ500.csv")
            data = np.genfromtxt(data_path, delimiter=",", skip_header=1)
        case "Hainan":
            data_path = os.path.join("./data/Hainan/Hainan.npz")
            data = np.load(data_path)["data"][:, :, 0]
        case "SD":
            data_path = os.path.join("./data/SD/data.npz")
            data = np.load(data_path)["data"][:, :, 0].astype(np.float32)
        case "BJTaxi-InFlow":
            data = read_BeijingTaxi()[:, :, 0:1].astype(np.float32)
        case "BJTaxi-OutFlow":
            data = read_BeijingTaxi()[:, :, 1:2].astype(np.float32)
        case _:
            raise ValueError(f"Unsupported dataset: {dataset}")
    # Ensure data shape compatibility: promote (L, N) to (L, N, 1)
    if len(data.shape) == 2:
        data = np.expand_dims(data, axis=-1)
    print("Loading %s dataset..." % dataset)
    # return data[::sample]
    return data


def read_BeijingTaxi():
    """Concatenate the yearly TaxiBJ arrays and flatten the 32x32 grid into 1024 nodes with 2 flow channels."""
    files = ["TaxiBJ2013.npy", "TaxiBJ2014.npy", "TaxiBJ2015.npy",
             "TaxiBJ2016_1.npy", "TaxiBJ2016_2.npy"]
    all_data = []
    for file in files:
        data_path = os.path.join(f"./data/BeijingTaxi/{file}")
        data = np.load(data_path)
        all_data.append(data)
    all_data = np.concatenate(all_data, axis=0)
    # (T, 2, 32, 32) -> (T, 32, 32, 2) -> (T, 1024, 2); channel 0 is inflow, channel 1 is outflow
    time_num = all_data.shape[0]
    all_data = all_data.transpose(0, 2, 3, 1).reshape(time_num, 32 * 32, 2)
    return all_data
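

# Minimal usage sketch (illustrative, not part of the original module). Only the
# config["basic"]["dataset"] key is read by the loader above; the rest of the
# example config dict and the chosen dataset are assumptions.
if __name__ == "__main__":
    example_config = {"basic": {"dataset": "PEMSD4"}}  # hypothetical minimal config
    array = load_st_dataset(example_config)
    # For PEMS04 this is typically (16992, 307, 1): time steps x nodes x channels.
    print(array.shape)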