Add a more data-efficient training data generation method.

This commit is contained in:
Yaguang 2018-06-07 10:41:47 +08:00
parent e93435c598
commit 562514cc30
2 changed files with 64 additions and 0 deletions

View File

@ -156,6 +156,59 @@ def generate_graph_seq2seq_io_data_with_time(df, batch_size, seq_len, horizon, n
return x, y
def generate_graph_seq2seq_io_data_with_time2(df, batch_size, seq_len, horizon, num_nodes, scaler=None,
                                              add_time_in_day=True, add_day_in_week=False):
    """
    Build batched seq2seq inputs/targets from every sliding window of a time-indexed frame.

    Window ``k`` uses rows ``[k, k + seq_len)`` as input and rows
    ``[k + seq_len, k + seq_len + horizon)`` as target. Consecutive windows are grouped into
    batches, i.e., ``x[i, j]`` is window ``i * batch_size + j``. Trailing windows that do not fill
    a complete batch are dropped.

    :param df: 2-D frame of shape (num_samples, num_nodes). Its index must be datetime-like when
        add_time_in_day or add_day_in_week is set.
    :param batch_size: number of windows per batch.
    :param seq_len: number of input time steps per window.
    :param horizon: number of target time steps per window.
    :param num_nodes: expected number of columns in df.
    :param scaler: optional object with a ``transform(df)`` method, applied to df first.
    :param add_time_in_day: append the fraction of the day elapsed (in [0, 1)) as a feature.
    :param add_day_in_week: append a 7-way one-hot encoding of the day of week as features.
    :return:
    x, y, both are 5-D tensors:
    x has size (epoch_size, batch_size, seq_len, num_nodes, num_features),
    y has size (epoch_size, batch_size, horizon, num_nodes, num_features),
    where epoch_size = (num_samples - seq_len - horizon + 1) // batch_size. epoch_size is 0
    (empty tensors) when there are windows but fewer than batch_size of them.
    :raises AssertionError: if df does not have num_nodes columns.
    :raises ValueError: if df has fewer than seq_len + horizon rows, so no window fits.
    """
    if scaler is not None:
        df = scaler.transform(df)
    num_samples, num_columns = df.shape
    assert num_columns == num_nodes, 'df has %d columns, expected num_nodes=%d' % (num_columns, num_nodes)

    # Every feature block is (num_samples, num_nodes, k) so they concatenate on the last axis.
    feature_blocks = [np.expand_dims(df.values, axis=-1)]
    if add_time_in_day:
        timestamps = df.index.values
        # Time elapsed since midnight, expressed as a fraction of one day.
        time_ind = (timestamps - timestamps.astype('datetime64[D]')) / np.timedelta64(1, 'D')
        # Same time feature for every node.
        feature_blocks.append(np.tile(time_ind[:, np.newaxis, np.newaxis], (1, num_nodes, 1)))
    if add_day_in_week:
        day_in_week = np.zeros(shape=(num_samples, num_nodes, 7))
        day_in_week[np.arange(num_samples), :, df.index.dayofweek] = 1
        feature_blocks.append(day_in_week)
    # data: (num_samples, num_nodes, num_features)
    data = np.concatenate(feature_blocks, axis=-1)
    num_features = data.shape[-1]

    num_windows = num_samples - seq_len - horizon + 1
    if num_windows < 1:
        raise ValueError('Need at least seq_len + horizon = %d samples to build one window, got %d.'
                         % (seq_len + horizon, num_samples))
    epoch_size = num_windows // batch_size

    # Gather only the windows that end up in a full batch, using index arrays instead of a
    # Python loop: row k of `starts + offsets` lists the sample indices of window k.
    starts = np.arange(epoch_size * batch_size)[:, np.newaxis]
    x = data[starts + np.arange(seq_len)]
    y = data[starts + seq_len + np.arange(horizon)]

    x = x.reshape(epoch_size, batch_size, seq_len, num_nodes, num_features)
    y = y.reshape(epoch_size, batch_size, horizon, num_nodes, num_features)
    return x, y
def load_pickle(pickle_file):
try:
with open(pickle_file, 'rb') as f:
@ -168,6 +221,7 @@ def load_pickle(pickle_file):
raise
return pickle_data
def round_down(num, divisor):
    """
    Remove from ``num`` its remainder modulo ``divisor``.

    :param num: value to round.
    :param divisor: step size; for a positive divisor the result is the largest multiple of
        ``divisor`` that does not exceed ``num``.
    :return: ``num - (num % divisor)``.
    """
    remainder = num % divisor
    return num - remainder

View File

@ -44,6 +44,16 @@ class IODataPreparationTest(unittest.TestCase):
self.assertTupleEqual(xs.shape, (3, 2, 9))
self.assertTupleEqual(ys.shape, (3, 2, 6))
def test_generate_graph_seq2seq_io_data_with_time(self):
    """Check the output shapes of ``utils.generate_graph_seq2seq_io_data_with_time2``."""
    # 16 samples for 2 nodes; the second node's reading is always the first node's plus one.
    first_node = np.arange(16)
    readings = np.stack([first_node, first_node + 1], axis=1).astype(np.float32)
    timestamps = pd.date_range('2017-10-18', '2017-10-19 23:59', freq='3h')
    df = pd.DataFrame(readings, index=timestamps)

    xs, ys = utils.generate_graph_seq2seq_io_data_with_time2(
        df, batch_size=2, seq_len=3, horizon=3, num_nodes=2)

    # Expected layout: (epoch_size, batch_size, steps, num_nodes, num_features).
    expected_shape = (5, 2, 3, 2, 2)
    self.assertTupleEqual(xs.shape, expected_shape)
    self.assertTupleEqual(ys.shape, expected_shape)
class StandardScalerTest(unittest.TestCase):
def test_transform(self):