Add a more data-efficient training data generation method.
This commit is contained in:
parent
e93435c598
commit
562514cc30
54
lib/utils.py
54
lib/utils.py
|
|
@ -156,6 +156,59 @@ def generate_graph_seq2seq_io_data_with_time(df, batch_size, seq_len, horizon, n
|
||||||
return x, y
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
|
def generate_graph_seq2seq_io_data_with_time2(df, batch_size, seq_len, horizon, num_nodes, scaler=None,
|
||||||
|
add_time_in_day=True, add_day_in_week=False):
|
||||||
|
"""
|
||||||
|
|
||||||
|
:param df:
|
||||||
|
:param batch_size:
|
||||||
|
:param seq_len:
|
||||||
|
:param horizon:
|
||||||
|
:param scaler:
|
||||||
|
:param add_time_in_day:
|
||||||
|
:param add_day_in_week:
|
||||||
|
:return:
|
||||||
|
x, y, both are 5-D tensors with size (epoch_size, batch_size, seq_len, num_sensors, input_dim).
|
||||||
|
Adjacent batches are continuous sequence, i.e., x[i, j, :, :] is before x[i+1, j, :, :]
|
||||||
|
"""
|
||||||
|
if scaler:
|
||||||
|
df = scaler.transform(df)
|
||||||
|
num_samples, _ = df.shape
|
||||||
|
assert df.shape[1] == num_nodes
|
||||||
|
data = df.values
|
||||||
|
data = np.expand_dims(data, axis=-1)
|
||||||
|
data_list = [data]
|
||||||
|
if add_time_in_day:
|
||||||
|
time_ind = (df.index.values - df.index.values.astype('datetime64[D]')) / np.timedelta64(1, 'D')
|
||||||
|
time_in_day = np.tile(time_ind, [1, num_nodes, 1]).transpose((2, 1, 0))
|
||||||
|
data_list.append(time_in_day)
|
||||||
|
if add_day_in_week:
|
||||||
|
day_in_week = np.zeros(shape=(num_samples, num_nodes, 7))
|
||||||
|
day_in_week[np.arange(num_samples), :, df.index.dayofweek] = 1
|
||||||
|
data_list.append(day_in_week)
|
||||||
|
|
||||||
|
# data: (num_samples, num_nodes, num_features)
|
||||||
|
data = np.concatenate(data_list, axis=-1)
|
||||||
|
num_features = data.shape[-1]
|
||||||
|
|
||||||
|
# Extract x and y
|
||||||
|
epoch_size = num_samples - seq_len - horizon + 1
|
||||||
|
x, y = [], []
|
||||||
|
for i in range(epoch_size):
|
||||||
|
x_i = data[i: i + seq_len, ...]
|
||||||
|
y_i = data[i + seq_len: i + seq_len + horizon, ...]
|
||||||
|
x.append(x_i)
|
||||||
|
y.append(y_i)
|
||||||
|
x = np.stack(x, axis=0)
|
||||||
|
y = np.stack(y, axis=0)
|
||||||
|
epoch_size //= batch_size
|
||||||
|
x = x[:batch_size * epoch_size, ...]
|
||||||
|
y = y[:batch_size * epoch_size, ...]
|
||||||
|
x = x.reshape(epoch_size, batch_size, seq_len, num_nodes, num_features)
|
||||||
|
y = y.reshape(epoch_size, batch_size, horizon, num_nodes, num_features)
|
||||||
|
return x, y
|
||||||
|
|
||||||
|
|
||||||
def load_pickle(pickle_file):
|
def load_pickle(pickle_file):
|
||||||
try:
|
try:
|
||||||
with open(pickle_file, 'rb') as f:
|
with open(pickle_file, 'rb') as f:
|
||||||
|
|
@ -168,6 +221,7 @@ def load_pickle(pickle_file):
|
||||||
raise
|
raise
|
||||||
return pickle_data
|
return pickle_data
|
||||||
|
|
||||||
|
|
||||||
def round_down(num, divisor):
|
def round_down(num, divisor):
|
||||||
return num - (num % divisor)
|
return num - (num % divisor)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,16 @@ class IODataPreparationTest(unittest.TestCase):
|
||||||
self.assertTupleEqual(xs.shape, (3, 2, 9))
|
self.assertTupleEqual(xs.shape, (3, 2, 9))
|
||||||
self.assertTupleEqual(ys.shape, (3, 2, 6))
|
self.assertTupleEqual(ys.shape, (3, 2, 6))
|
||||||
|
|
||||||
|
def test_generate_graph_seq2seq_io_data_with_time(self):
|
||||||
|
data = np.array([
|
||||||
|
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15],
|
||||||
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
|
||||||
|
], dtype=np.float32).T
|
||||||
|
df = pd.DataFrame(data, index=pd.date_range('2017-10-18', '2017-10-19 23:59', freq='3h'))
|
||||||
|
xs, ys = utils.generate_graph_seq2seq_io_data_with_time2(df, batch_size=2, seq_len=3, horizon=3, num_nodes=2)
|
||||||
|
self.assertTupleEqual(xs.shape, (5, 2, 3, 2, 2))
|
||||||
|
self.assertTupleEqual(ys.shape, (5, 2, 3, 2, 2))
|
||||||
|
|
||||||
|
|
||||||
class StandardScalerTest(unittest.TestCase):
|
class StandardScalerTest(unittest.TestCase):
|
||||||
def test_transform(self):
|
def test_transform(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue