TrafficWheel/dataloader/cde_loader/add_window.py

38 lines
1.0 KiB
Python
Executable File

import numpy as np
def Add_Window_Horizon(data, window=3, horizon=1, single=False):
    '''
    Slice a time-ordered sequence into (input window, target horizon) pairs.

    Sample i uses X = data[i : i + window] as input. Its target is either the
    next ``horizon`` steps after the window, or, when ``single`` is True, only
    the step ``horizon`` positions after the window.

    :param data: array-like indexed along its first axis (time), shape [T, ...]
    :param window: number of consecutive steps in each input sample (>= 1)
    :param horizon: number of future steps after each window (>= 1)
    :param single: if True, Y keeps only the last horizon step (length 1);
        otherwise Y holds all ``horizon`` steps
    :return: X of shape [B, window, ...] and Y of shape [B, horizon, ...]
        (or [B, 1, ...] when ``single``), where B = T - window - horizon + 1.
        If T is too short, B is 0 and both arrays keep the correct rank.
    :raises ValueError: if window or horizon is smaller than 1
    '''
    if window < 1 or horizon < 1:
        raise ValueError(
            'window and horizon must be >= 1, got window={}, horizon={}'.format(
                window, horizon))
    data = np.asarray(data)
    num_samples = max(len(data) - window - horizon + 1, 0)
    y_len = 1 if single else horizon
    if num_samples == 0:
        # Keep the trailing feature dims so callers can still rely on rank/shape.
        tail = data.shape[1:]
        return (np.empty((0, window) + tail, dtype=data.dtype),
                np.empty((0, y_len) + tail, dtype=data.dtype))
    # Offset of Y relative to the window start: right after the window, or at
    # the final horizon step when only a single target step is requested.
    y_offset = window + horizon - 1 if single else window
    X = np.array([data[i:i + window] for i in range(num_samples)])
    Y = np.array([data[i + y_offset:i + y_offset + y_len]
                  for i in range(num_samples)])
    return X, Y
if __name__ == '__main__':
    # Manual smoke check: window the demand series and print input/output shapes.
    from data.load_raw_data import Load_Sydney_Demand_Data

    raw_series = Load_Sydney_Demand_Data('../data/1h_data_new3.csv')
    print(raw_series.shape)
    inputs, targets = Add_Window_Horizon(raw_series, horizon=2)
    print(inputs.shape, targets.shape)