# TrafficWheel/dataloader/TSloader.py
from dataloader.data_selector import load_st_dataset
from utils.normalization import normalize_dataset
import numpy as np
import torch
def get_dataloader(args, normalizer="std", single=True):
    """Build train/val/test DataLoaders and the fitted scaler.

    :param args: full config dict; the "data" sub-dict carries loader options
                 (lag, horizon, val_ratio, test_ratio, num_nodes, column_wise,
                 batch_size)
    :param normalizer: normalization scheme name passed to the scaler factory
    :param single: forwarded to the window generators (single-step targets)
    :return: (train_loader, val_loader, test_loader, scaler)
    """
    raw = load_st_dataset(args)
    cfg = args["data"]
    # Sanity check: the raw series is expected to be 3-D [L, N, F].
    seq_len, num_nodes, num_feats = raw.shape
    # Sliding windows over the time axis -> x, y shaped [b, t, n, c]
    x, y = _prepare_data_with_windows(raw, cfg, single)
    # A test_ratio > 1 is interpreted as a number of days, otherwise a fraction.
    splitter = split_data_by_days if cfg["test_ratio"] > 1 else split_data_by_ratio
    x_train, x_val, x_test = splitter(x, cfg["val_ratio"], cfg["test_ratio"])
    y_train, y_val, y_test = splitter(y, cfg["val_ratio"], cfg["test_ratio"])
    # Fit the scaler on x_train only, then reuse it on every split of x and y.
    scaler = _normalize_data(x_train, x_val, x_test, cfg, normalizer)
    _apply_existing_scaler(y_train, y_val, y_test, scaler, cfg)
    # Fold the node axis into the batch axis: [b, t, n, c] -> [b*n, t, c]
    x_train, x_val, x_test, y_train, y_val, y_test = _reshape_tensor(
        x_train, x_val, x_test, y_train, y_val, y_test
    )
    train_loader = _create_dataloader(x_train, y_train, cfg["batch_size"], True, False)
    val_loader = _create_dataloader(x_val, y_val, cfg["batch_size"], False, False)
    test_loader = _create_dataloader(x_test, y_test, cfg["batch_size"], False, False)
    return train_loader, val_loader, test_loader, scaler
def _reshape_tensor(*tensors):
"""Reshape tensors from [b, t, n, c] -> [b*n, t, c]."""
reshaped = []
for x in tensors:
# x 是 ndarrayshape (b, t, n, c)
b, t, n, c = x.shape
x_new = x.transpose(0, 2, 1, 3).reshape(b * n, t, c)
reshaped.append(x_new)
return reshaped
def _prepare_data_with_windows(data, args, single):
    """Produce matching input/target sliding windows from the raw series.

    :param data: raw series, shape [L, N, F]
    :param args: data config dict providing "lag" and "horizon"
    :param single: if True, targets are single-step
    :return: (x, y) window arrays
    """
    lag, horizon = args["lag"], args["horizon"]
    inputs = add_window_x(data, lag, horizon, single)
    targets = add_window_y(data, lag, horizon, single)
    return inputs, targets
def _normalize_data(train_data, val_data, test_data, args, normalizer):
    """Fit a scaler on the training split and normalize all splits in place.

    NOTE(review): args["num_nodes"] slices the LAST (channel) axis here, not
    the node axis — verify the key name matches the intended semantics.

    :param train_data: training split, mutated in place
    :param val_data: validation split, mutated in place
    :param test_data: test split, mutated in place
    :param args: data config providing "num_nodes" and "column_wise"
    :param normalizer: normalization scheme name
    :return: the fitted scaler (for later inverse transforms)
    """
    channels = args["num_nodes"]
    # Fit on the training data only to avoid leakage into val/test.
    scaler = normalize_dataset(train_data[..., :channels], normalizer, args["column_wise"])
    for split in (train_data, val_data, test_data):
        split[..., :channels] = scaler.transform(split[..., :channels])
    return scaler
def _apply_existing_scaler(train_data, val_data, test_data, scaler, args):
for data in [train_data, val_data, test_data]:
data[..., : args["num_nodes"]] = scaler.transform(data[..., : args["num_nodes"]])
def _create_dataloader(X_data, Y_data, batch_size, shuffle, drop_last):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_tensor = torch.tensor(X_data, dtype=torch.float32, device=device)
Y_tensor = torch.tensor(Y_data, dtype=torch.float32, device=device)
dataset = torch.utils.data.TensorDataset(X_tensor, Y_tensor)
return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, drop_last=drop_last)
def split_data_by_days(data, val_days, test_days, interval=30):
    """Split a time-ordered series into train/val/test by whole days.

    :param data: array ordered oldest -> newest along axis 0
    :param val_days: number of validation days (taken just before the test span)
    :param test_days: number of test days (taken from the end)
    :param interval: minutes per time step (default 30 -> 48 steps/day)
    :return: (train_data, val_data, test_data)
    """
    steps_per_day = int((24 * 60) / interval)
    length = len(data)
    test_len = steps_per_day * int(test_days)
    tail_len = steps_per_day * int(test_days + val_days)
    # Index from the front: the original negative-index slices broke when a
    # span was empty (data[-0:] returns the WHOLE array, not an empty one).
    test_data = data[length - test_len:]
    val_data = data[length - tail_len : length - test_len]
    train_data = data[: length - tail_len]
    return train_data, val_data, test_data
def split_data_by_ratio(data, val_ratio, test_ratio):
    """Split a time-ordered series into train/val/test by fractional size.

    Lengths are floored with int(), matching the original arithmetic, so the
    training split absorbs any rounding remainder.

    :param data: array ordered oldest -> newest along axis 0
    :param val_ratio: fraction of samples for validation (taken before test)
    :param test_ratio: fraction of samples for test (taken from the end)
    :return: (train_data, val_data, test_data)
    """
    data_len = data.shape[0]
    test_len = int(data_len * test_ratio)
    tail_len = int(data_len * (test_ratio + val_ratio))
    # Index from the front: the original negative-index slices broke when a
    # ratio was 0 (data[-0:] returns the WHOLE array, not an empty one).
    test_data = data[data_len - test_len:]
    val_data = data[data_len - tail_len : data_len - test_len]
    train_data = data[: data_len - tail_len]
    return train_data, val_data, test_data
def _generate_windows(data, window=3, horizon=1, offset=0):
"""
Internal helper function to generate sliding windows.
:param data: Input data, shape [L, T, C]
:param window: Window size
:param horizon: Horizon size
:param offset: Offset from window start
:return: Windowed data, shape [num_windows, window, T, C]
"""
length = len(data)
end_index = length - horizon - window + 1
windows = []
index = 0
if end_index <= 0:
raise ValueError(f"end_index is non-positive: {end_index}, length={length}, horizon={horizon}, window={window}")
while index < end_index:
window_data = data[index + offset : index + offset + window]
windows.append(window_data)
index += 1
if not windows:
raise ValueError("No windows generated")
# Check window shapes
first_shape = windows[0].shape
for i, w in enumerate(windows):
if w.shape != first_shape:
raise ValueError(f"Window {i} has shape {w.shape}, expected {first_shape}")
return np.array(windows)
def add_window_x(data, window=3, horizon=1, single=False):
    """
    Generate windowed X values from the input data.

    :param data: Input data, shape [L, T, C]
    :param window: Size of the sliding window (lag)
    :param horizon: Horizon size reserved for the matching Y windows
    :param single: Unused for X; kept for signature symmetry with add_window_y
    :return: X with shape [num_windows, window, T, C]
    """
    # X windows always start at the window index itself (offset 0).
    return _generate_windows(data, window, horizon, offset=0)
def add_window_y(data, window=3, horizon=1, single=False):
    """
    Generate windowed Y values from the input data.

    :param data: Input data, shape [L, T, C]
    :param window: Size of the sliding window (lag) used for X
    :param horizon: Horizon size
    :param single: If True, each Y is the single step at the end of the
                   horizon; otherwise each Y spans the full horizon
    :return: Y with shape [num_windows, 1 or horizon, T, C]
    """
    if single:
        start_shift, y_len = window + horizon - 1, 1
    else:
        start_shift, y_len = window, horizon
    return _generate_windows(data, window=y_len, horizon=horizon, offset=start_shift)
if __name__ == "__main__":
    # Smoke test: build loaders for the METR-LA dataset with a minimal config.
    config = {
        "basic": {"dataset": "METR-LA"},
        "data": {
            "lag": 3,
            "horizon": 1,
            "val_ratio": 0.1,
            "test_ratio": 0.2,
            "steps_per_day": 288,
            "days_per_week": 7,
            "input_dim": 1,
            "column_wise": False,
            "batch_size": 32,
            "time_dim": 1,  # Add time dimension parameter
        },
    }
    try:
        data = load_st_dataset(config)
        print(f"Original data shape: {data.shape}")
        train_loader, val_loader, test_loader, scaler = get_dataloader(config)
        # Inspect a single batch to confirm the pipeline runs end to end.
        for batch_x, batch_y in train_loader:
            print(f"Batch X shape: {batch_x.shape}")
            print(f"Batch Y shape: {batch_y.shape}")
            break
        print("Test passed successfully!")
    except Exception as e:
        # Best-effort smoke test: report the failure instead of crashing.
        print(f"Test failed with error: {e}")
        import traceback
        traceback.print_exc()