# Source listing metadata: 116 lines, 4.0 KiB, Python
import os
|
|
import json
|
|
import argparse
|
|
import numpy as np
|
|
|
|
|
|
def generate_offsets(seq_length_x: int, seq_length_y: int):
    """Build the relative time offsets for the input and target windows.

    Args:
        seq_length_x: Number of history steps in each input window.
        seq_length_y: Number of future steps in each target window.

    Returns:
        A pair ``(x_offsets, y_offsets)`` of ascending 1-D integer arrays.
        ``x_offsets`` runs from ``-(seq_length_x - 1)`` to ``0`` inclusive and
        ``y_offsets`` runs from ``1`` to ``seq_length_y`` inclusive.

    Raises:
        ValueError: If either length is smaller than 1. Such lengths would
            otherwise produce empty offset arrays that only fail later with
            an unrelated-looking reduction error.
    """
    if seq_length_x < 1 or seq_length_y < 1:
        raise ValueError(
            "seq_length_x and seq_length_y must be >= 1, "
            f"got {seq_length_x} and {seq_length_y}"
        )

    # np.arange already yields ascending values, so no sort/concatenate is needed.
    x_offsets = np.arange(-(seq_length_x - 1), 1)
    y_offsets = np.arange(1, seq_length_y + 1)
    return x_offsets, y_offsets
|
|
|
|
|
|
def make_sliding_windows(data: np.ndarray, x_offsets: np.ndarray, y_offsets: np.ndarray):
    """Cut a time series into aligned (input, target) sliding windows.

    Args:
        data: Array indexed by time on axis 0. The final transpose uses four
            axes, so ``data`` must be 3-D, i.e. ``(T, N, C)``.
        x_offsets: 1-D integer offsets (relative to the anchor step) that
            form each input window.
        y_offsets: 1-D integer offsets (relative to the anchor step) that
            form each target window.

    Returns:
        A pair ``(x, y)`` of ``float32`` arrays with shapes
        ``(S, N, len(x_offsets), C)`` and ``(S, N, len(y_offsets), C)``,
        where ``S`` is the number of anchor steps.

    Raises:
        ValueError: If the series is too short to yield a single window.
    """
    num_steps = data.shape[0]
    first_anchor = abs(int(np.min(x_offsets)))
    stop_anchor = num_steps - int(np.max(y_offsets))

    if stop_anchor <= first_anchor:
        # Fail with a clear message instead of an opaque error from stacking
        # an empty list of windows.
        raise ValueError(
            f"Not enough time steps ({num_steps}) to build a window with "
            f"x_offsets min={int(np.min(x_offsets))} and "
            f"y_offsets max={int(np.max(y_offsets))}"
        )

    # One row of absolute time indices per anchor; broadcasting the anchors
    # against the offsets replaces the per-step Python loop with one gather.
    anchors = np.arange(first_anchor, stop_anchor)[:, np.newaxis]
    x = data[anchors + x_offsets[np.newaxis, :]].astype(np.float32)  # (S, seq_len, N, C)
    y = data[anchors + y_offsets[np.newaxis, :]].astype(np.float32)  # (S, pred_len, N, C)

    # Reorder to (S, N, L, C) to match the model expectation: b n l m
    x = np.transpose(x, (0, 2, 1, 3))
    y = np.transpose(y, (0, 2, 1, 3))
    return x, y
|
|
|
|
|
|
def split_by_ratio(x: np.ndarray, y: np.ndarray, ratios):
    """Split samples into contiguous train/val/test parts, in order.

    Args:
        x: Input samples, split along axis 0.
        y: Target samples, split along axis 0 with the same boundaries.
        ratios: Three values ``(train, val, test)``. Only the train and val
            ratios set the boundaries; the test part is whatever remains.

    Returns:
        ``((x_train, y_train), (x_val, y_val), (x_test, y_test))``.

    Raises:
        ValueError: If ``ratios`` does not unpack into exactly three values.
    """
    r_train, r_val, _ = ratios
    num_samples = x.shape[0]

    # Clamp so rounding can never push a boundary past the end of the data,
    # which would make the parts overlap.
    n_train = min(int(round(num_samples * r_train)), num_samples)
    n_val = min(int(round(num_samples * r_val)), num_samples - n_train)
    val_end = n_train + n_val

    x_train, y_train = x[:n_train], y[:n_train]
    x_val, y_val = x[n_train:val_end], y[n_train:val_end]
    # Slice forward from the end of validation rather than with a negative
    # count: x[-0:] is the whole array, so an empty test part would otherwise
    # come back as the full dataset.
    x_test, y_test = x[val_end:], y[val_end:]
    return (x_train, y_train), (x_val, y_val), (x_test, y_test)
|
|
|
|
|
|
def main():
    """Convert the raw PEMS-BAY dump into train/val/test ``.npz`` window files.

    Reads ``desc.json`` (array shape and optional split ratios) and
    ``data.dat`` (flat float32 values) from ``--dataset_dir``, keeps only the
    first channel, builds sliding windows of ``--seq_len`` history steps and
    ``--pred_len`` future steps, splits them in order by ratio, and writes
    ``train.npz``, ``val.npz`` and ``test.npz`` into the same directory.

    Raises:
        FileNotFoundError: If ``desc.json`` or ``data.dat`` is missing.
        ValueError: If the shape in ``desc.json`` is missing, is not 2-D or
            3-D, or does not match the number of values in ``data.dat``.
    """
    parser = argparse.ArgumentParser(description="Prepare PEMS-BAY to train/val/test .npz")
    parser.add_argument(
        "--dataset_dir",
        type=str,
        default='./datasets/PEMS-BAY',
        help="Path to datasets/PEMS-BAY directory",
    )
    parser.add_argument("--seq_len", type=int, default=12)
    parser.add_argument("--pred_len", type=int, default=12)
    # NOTE: the first channel is always selected further down because the
    # model expects C=1, so this flag has no effect. It is still accepted so
    # existing command lines keep working.
    parser.add_argument(
        "--speed_channel_only",
        action="store_true",
        help=(
            "Use only the first channel (speed); this is always applied, "
            "the flag is kept for compatibility"
        ),
    )
    args = parser.parse_args()

    dataset_dir = args.dataset_dir
    desc_path = os.path.join(dataset_dir, "desc.json")
    data_path = os.path.join(dataset_dir, "data.dat")

    if not os.path.exists(desc_path):
        raise FileNotFoundError(f"desc.json not found at {desc_path}")
    if not os.path.exists(data_path):
        raise FileNotFoundError(f"data.dat not found at {data_path}")

    with open(desc_path, "r", encoding="utf-8") as f:
        desc = json.load(f)

    shape = desc.get("shape")  # expected [T, N, C] or [T, N]
    if not shape or len(shape) not in (2, 3):
        raise ValueError(f"Invalid shape in desc.json: {shape}")

    # A 2-D shape [T, N] holds the same number of values as [T, N, 1], so a
    # single element-count check covers both layouts.
    raw = np.fromfile(data_path, dtype=np.float32)
    if raw.size != int(np.prod(shape)):
        raise ValueError(f"data.dat size mismatch. desc={shape}, fromfile={raw.size}")

    dims = tuple(shape)
    if len(dims) == 2:
        dims += (1,)  # promote (T, N) to (T, N, 1)
    data = raw.reshape(dims)

    # Use only speed channel for this model (expects C=1)
    if data.shape[-1] > 1:
        data = data[..., :1]

    x_offsets, y_offsets = generate_offsets(args.seq_len, args.pred_len)
    x, y = make_sliding_windows(data, x_offsets, y_offsets)

    ratios = desc.get("regular_settings", {}).get("TRAIN_VAL_TEST_RATIO", [0.7, 0.1, 0.2])
    (x_train, y_train), (x_val, y_val), (x_test, y_test) = split_by_ratio(x, y, ratios)

    for split_name, _x, _y in (
        ("train", x_train, y_train),
        ("val", x_val, y_val),
        ("test", x_test, y_test),
    ):
        out_path = os.path.join(dataset_dir, f"{split_name}.npz")
        np.savez_compressed(
            out_path,
            x=_x,
            y=_y,
            # Offsets are stored with a trailing singleton axis.
            x_offsets=x_offsets[..., np.newaxis],
            y_offsets=y_offsets[..., np.newaxis],
        )
        print(f"Saved {split_name} -> {out_path} | x={_x.shape}, y={_y.shape}")

    print("Done.")
|
|
|
|
|
|
# Run the preparation pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|
|
|