# Python payload recovered from a whitespace-collapsed patch against
# lib/utils.py (38c12fa..20eb52f) and lib/utils_test.py (301474a..b97b733).
# The patch adds generate_graph_seq2seq_io_data_with_time2, inserts a blank
# line in front of round_down, and adds a shape test whose scenario is kept as
# the docstring example below. load_pickle is visible only as partial hunk
# context and is therefore not reproduced here.
import numpy as np


def generate_graph_seq2seq_io_data_with_time2(df, batch_size, seq_len, horizon, num_nodes, scaler=None,
                                              add_time_in_day=True, add_day_in_week=False):
    """
    Slice a time-indexed frame into batched (input, target) sliding windows.

    Every row of ``df`` is one time step and every column one node. Window
    ``k`` uses rows ``[k, k + seq_len)`` as input and rows
    ``[k + seq_len, k + seq_len + horizon)`` as target. Windows are grouped
    into batches in order, so ``x[i, j]`` is the window that starts at row
    ``i * batch_size + j``; trailing windows that do not fill a whole batch
    are dropped.

    Example::

        data = np.array([range(0, 16), range(1, 17)], dtype=np.float32).T
        df = pd.DataFrame(data, index=pd.date_range('2017-10-18', '2017-10-19 23:59', freq='3h'))
        xs, ys = generate_graph_seq2seq_io_data_with_time2(
            df, batch_size=2, seq_len=3, horizon=3, num_nodes=2)
        xs.shape == ys.shape == (5, 2, 3, 2, 2)

    :param df: frame of shape (num_samples, num_nodes); its index must be
        datetime-like when either time feature is requested.
    :param batch_size: number of windows per batch.
    :param seq_len: number of time steps in each input window.
    :param horizon: number of time steps in each target window.
    :param num_nodes: expected number of columns in ``df``.
    :param scaler: optional object whose ``transform(df)`` result replaces
        ``df`` before any feature is built.
    :param add_time_in_day: append the fraction of the day elapsed (in [0, 1))
        as one extra feature.
    :param add_day_in_week: append a 7-way one-hot encoding of the weekday
        (Monday first) as seven extra features.
    :return:
        x, y, both 5-D arrays. x has shape
        (epoch_size, batch_size, seq_len, num_nodes, num_features) and y has
        shape (epoch_size, batch_size, horizon, num_nodes, num_features).
        Both carry every feature, not only the raw values. When ``df`` is too
        short to yield a full batch, ``epoch_size`` is 0 and both arrays are
        empty (with the trailing dimensions still set).
    """
    if scaler:
        df = scaler.transform(df)
    num_samples, num_columns = df.shape
    assert num_columns == num_nodes, (
        'df has %d columns but num_nodes is %d' % (num_columns, num_nodes))

    # The raw readings form feature 0: (num_samples, num_nodes, 1).
    feature_blocks = [np.expand_dims(df.values, axis=-1)]
    if add_time_in_day:
        timestamps = df.index.values
        # Offset from midnight divided by one day -> fraction of the day.
        day_fraction = (timestamps - timestamps.astype('datetime64[D]')) / np.timedelta64(1, 'D')
        # Same value for every node: (num_samples, num_nodes, 1).
        feature_blocks.append(np.tile(day_fraction[:, np.newaxis, np.newaxis], (1, num_nodes, 1)))
    if add_day_in_week:
        day_in_week = np.zeros(shape=(num_samples, num_nodes, 7))
        # For sample s, set slot dayofweek[s] to 1 across all nodes.
        day_in_week[np.arange(num_samples), :, np.asarray(df.index.dayofweek)] = 1
        feature_blocks.append(day_in_week)

    # data: (num_samples, num_nodes, num_features)
    data = np.concatenate(feature_blocks, axis=-1)
    num_features = data.shape[-1]

    # Clamp at zero so inputs shorter than one window give empty outputs
    # instead of failing while stacking an empty list of windows.
    num_windows = max(num_samples - seq_len - horizon + 1, 0)
    epoch_size = num_windows // batch_size

    # Gather only the windows that end up in a full batch. Row k of the index
    # matrix lists the sample positions of window k.
    window_starts = np.arange(epoch_size * batch_size)[:, np.newaxis]
    x = data[window_starts + np.arange(seq_len)]
    y = data[window_starts + seq_len + np.arange(horizon)]

    x = x.reshape(epoch_size, batch_size, seq_len, num_nodes, num_features)
    y = y.reshape(epoch_size, batch_size, horizon, num_nodes, num_features)
    return x, y


def round_down(num, divisor):
    """Subtract ``num % divisor`` from ``num``, i.e. round down to a multiple of a positive ``divisor``."""
    return num - (num % divisor)