TrafficWheel/dataloader/cde_loader/load_dataset.py

37 lines
1.9 KiB
Python

import os
import numpy as np
def load_st_dataset(dataset):
#output B, N, D
if dataset == 'PEMSD4':
data_path = os.path.join('../data/PEMS04/PEMS04.npz')
# data_path = os.path.join('../data/PeMSD4/pems04.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'PEMSD8':
data_path = os.path.join('../data/PEMS08/PEMS08.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'PEMSD3':
data_path = os.path.join('../data/PEMS03/PEMS03.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'PEMSD7':
data_path = os.path.join('../data/PEMS07/PEMS07.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'PEMSD7M':
data_path = os.path.join('../data/PEMS07M/PEMS07M.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'PEMSD7L':
data_path = os.path.join('../data/PEMS07L/PEMS07L.npz')
data = np.load(data_path)['data'][:, :, 0] #onley the first dimension, traffic flow data
elif dataset == 'Decentraland':
data_path = os.path.join('../token_data/Decentraland_node_features.npz')
data = np.load(data_path)['arr_0'][:, :, 0] #1 dimension, degree
elif dataset == 'Bytom':
data_path = os.path.join('../token_data/Bytom_node_features.npz')
data = np.load(data_path)['arr_0'][:, :, 0] #1 dimension, degree
else:
raise ValueError
if len(data.shape) == 2:
data = np.expand_dims(data, axis=-1)
print('Load %s Dataset shaped: ' % dataset, data.shape, data.max(), data.min(), data.mean(), np.median(data))
return data