# FS-TFP/federatedscope/trafficflow/dataset/add_window.py
# (38 lines, 1.1 KiB, Python)

import numpy as np
def add_window_horizon(data, window=3, horizon=1, single=False):
    """
    Slice a time-ordered sequence into (input window, target) sample pairs.

    Sample ``i`` uses ``data[i:i + window]`` as its input. Its target is the
    ``horizon`` steps that immediately follow that window. When ``single`` is
    True, the target is only the last of those steps,
    ``data[i + window + horizon - 1]``, kept with a length-1 axis.

    :param data: sequence indexed along its first axis, shape [T, ...]
        (a numpy array, or anything supporting ``len`` and slicing).
    :param window: number of consecutive steps in each input window (>= 1).
    :param horizon: number of steps after each window to predict (>= 1).
    :param single: if True, keep only the final horizon step as the target.
    :return: ``(X, Y)`` where X is [N, window, ...] and Y is [N, horizon, ...]
        (or [N, 1, ...] when ``single``), with
        ``N = max(T - window - horizon + 1, 0)``. When N is 0, the arrays are
        empty but keep these trailing dimensions.
    :raises ValueError: if ``window`` or ``horizon`` is less than 1.
    """
    if window < 1 or horizon < 1:
        raise ValueError(
            f'window and horizon must be >= 1, got window={window}, '
            f'horizon={horizon}')

    num_samples = len(data) - window - horizon + 1

    # Offset (relative to the window start) and length of each target slice.
    if single:
        y_offset, y_len = window + horizon - 1, 1
    else:
        y_offset, y_len = window, horizon

    if num_samples <= 0:
        # Too little data for even one sample. Return empty arrays that still
        # carry the [N, W, ...] / [N, H, ...] layout, instead of np.array([])
        # with shape (0,), which would break downstream shape handling.
        arr = np.asarray(data)
        tail = arr.shape[1:]
        return (np.empty((0, window) + tail, dtype=arr.dtype),
                np.empty((0, y_len) + tail, dtype=arr.dtype))

    x = np.array([data[i:i + window] for i in range(num_samples)])
    y = np.array([data[i + y_offset:i + y_offset + y_len]
                  for i in range(num_samples)])
    return x, y
# if __name__ == '__main__':
# from data.load_raw_data import Load_Sydney_Demand_Data
# path = '../data/1h_data_new3.csv'
# data = Load_Sydney_Demand_Data(path)
# print(data.shape)
# X, Y = add_window_horizon(data, horizon=2)
# print(X.shape, Y.shape)