TrafficWheel/dataloader/cde_loader/add_window.py

39 lines
1.1 KiB
Python
Executable File

import numpy as np
def Add_Window_Horizon(data, window=3, horizon=1, single=False):
    """Slice a sequence into sliding input windows and target horizons.

    Windows slide along the first axis of ``data`` with stride 1.

    :param data: sliceable sequence / array of shape [T, ...]; windows are
        taken along the first axis
    :param window: number of consecutive steps in each input sample X (>= 1)
    :param horizon: number of steps following the window used as target (>= 1)
    :param single: if True, Y keeps only the step that lies ``horizon`` steps
        after the window (shape [B, 1, ...]); otherwise Y keeps all
        ``horizon`` steps right after the window (shape [B, H, ...])
    :return: (X, Y) numpy arrays; X is [B, W, ...] and Y is [B, H, ...]
        (or [B, 1, ...] when ``single``), with B = T - window - horizon + 1.
        When B <= 0 both arrays are empty but keep their trailing dimensions.
    :raises ValueError: if ``window`` or ``horizon`` is smaller than 1
    """
    if window < 1 or horizon < 1:
        raise ValueError(
            f"window and horizon must be >= 1, got window={window}, horizon={horizon}"
        )

    # In single mode only the last step of the horizon is used as target.
    y_offset = window + horizon - 1 if single else window
    y_len = 1 if single else horizon

    num_samples = len(data) - window - horizon + 1
    if num_samples <= 0:
        # np.array([]) would collapse to shape (0,); keep the expected rank.
        arr = np.asarray(data)
        tail = tuple(arr.shape[1:])
        return (
            np.empty((0, window) + tail, dtype=arr.dtype),
            np.empty((0, y_len) + tail, dtype=arr.dtype),
        )

    X = np.array([data[i : i + window] for i in range(num_samples)])
    Y = np.array(
        [data[i + y_offset : i + y_offset + y_len] for i in range(num_samples)]
    )
    return X, Y
if __name__ == "__main__":
    # Manual smoke check: load raw data via the project loader and print the
    # shapes produced by Add_Window_Horizon with horizon=2.
    from data.load_raw_data import Load_Sydney_Demand_Data

    csv_path = "../data/1h_data_new3.csv"
    raw = Load_Sydney_Demand_Data(csv_path)
    print(raw.shape)

    windows, targets = Add_Window_Horizon(raw, horizon=2)
    print(windows.shape, targets.shape)