38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
import numpy as np
|
|
|
|
|
|
def add_window_horizon(data, window=3, horizon=1, single=False):
    """Slice a sequence into sliding input windows and their forecast targets.

    Sample ``i`` uses ``data[i:i + window]`` as its input window. Its target
    is the ``horizon`` steps that immediately follow the window or, when
    ``single`` is true, only the step ``horizon`` steps after the window
    ends (i.e. ``data[i + window + horizon - 1]``).

    :param data: array-like indexable along its first axis, shape [T, ...]
        (presumably time-major; anything ``np.asarray`` accepts).
    :param window: number of consecutive steps in each input window (>= 1).
    :param horizon: number of steps to forecast after each window (>= 1).
    :param single: if true, keep only the final step of the horizon instead
        of all ``horizon`` steps.
    :return: ``(x, y)`` where ``x`` has shape [B, window, ...] and ``y`` has
        shape [B, horizon, ...] (or [B, 1, ...] when ``single``), with
        ``B = max(T - window - horizon + 1, 0)``. When ``B == 0`` both arrays
        are empty but still carry their window/trailing dimensions.
    :raises ValueError: if ``window`` or ``horizon`` is less than 1.
    """
    if window < 1 or horizon < 1:
        raise ValueError(
            f"window and horizon must be >= 1, got window={window}, horizon={horizon}"
        )

    values = np.asarray(data)
    num_samples = max(len(values) - window - horizon + 1, 0)

    # In single mode the target is one step located `horizon` steps past the
    # window; otherwise it is the whole horizon directly after the window.
    target_len = 1 if single else horizon
    target_offset = window + horizon - target_len

    # Row i of each index matrix lists the time steps belonging to sample i.
    # With num_samples == 0 the index matrices are empty, so the results are
    # empty arrays of shape (0, window, ...) / (0, target_len, ...).
    starts = np.arange(num_samples)[:, np.newaxis]
    x = values[starts + np.arange(window)]
    y = values[starts + target_offset + np.arange(target_len)]
    return x, y
|
|
|
|
# if __name__ == '__main__':
|
|
# from data.load_raw_data import Load_Sydney_Demand_Data
|
|
# path = '../data/1h_data_new3.csv'
|
|
# data = Load_Sydney_Demand_Data(path)
|
|
# print(data.shape)
|
|
# X, Y = add_window_horizon(data, horizon=2)
|
|
# print(X.shape, Y.shape)
|