86 lines
3.0 KiB
Python
86 lines
3.0 KiB
Python
import numpy as np
|
|
import math
|
|
|
|
|
|
def batch_iter(data, batch_size, shuffled=True):
    """
    Endlessly iterate over a dataset in mini-batches.

    One epoch walks through the whole dataset exactly once; the last batch
    of an epoch may hold fewer than ``batch_size`` samples. When an epoch is
    exhausted a new one starts immediately, so this generator never stops on
    its own -- the caller decides when to stop pulling batches.

    Arguments:
        data (dict): must contain the keys 'x' and 'y'; both values are
            converted with ``np.asarray``. The number of samples is taken
            from the length of 'y'.
        batch_size (int): the batch size, must be positive
        shuffled (bool): whether to shuffle the data at the start of each epoch

    :returns: sample indices, batch of x, batch of y
    :rtype: ndarray, ndarray, ndarray
    :raises ValueError: (on the first ``next()``) if ``batch_size`` is not
        positive or the dataset is empty
    """
    assert 'x' in data and 'y' in data
    data_x = np.asarray(data['x'])
    data_y = np.asarray(data['y'])
    data_size = len(data_y)

    # Without these guards a non-positive batch count makes the epoch loop
    # below spin forever without ever yielding (or divide by zero).
    if batch_size <= 0:
        raise ValueError(
            'batch_size must be positive, got {}'.format(batch_size))
    if data_size == 0:
        raise ValueError('cannot iterate over an empty dataset')

    num_batches_per_epoch = math.ceil(data_size / batch_size)

    while True:
        # A fresh ordering is drawn for every epoch when shuffling is on.
        if shuffled:
            epoch_order = np.random.permutation(np.arange(data_size))
        else:
            epoch_order = np.arange(data_size)

        for batch in range(num_batches_per_epoch):
            start_index = batch * batch_size
            # Clamp so the final (possibly partial) batch stays in range.
            end_index = min(data_size, start_index + batch_size)
            sample_index = epoch_order[start_index:end_index]
            yield sample_index, data_x[sample_index], data_y[sample_index]
|
|
|
|
|
|
class VerticalDataSampler(object):
    """
    VerticalDataSampler is used to sample a subset (rows and/or feature
    columns) from data

    Arguments:
        data (dict): must contain the key 'x' (converted with ``np.asarray``
            and indexed as a 2-D array: rows are samples, columns are
            features); may contain the key 'y'
        replace (bool): Whether the row sample is with or without replacement
        use_full_trainset (bool): if True, ``sample_data`` skips sampling and
            always returns the entire dataset
        feature_frac (float): fraction of the feature columns picked by
            ``sample_feature``. The resulting count is truncated to an int
            and kept between 1 and the number of available features.
    """

    def __init__(self,
                 data,
                 replace=False,
                 use_full_trainset=True,
                 feature_frac=1.0):
        assert 'x' in data
        self.data_x = np.asarray(data['x'])
        self.data_y = np.asarray(data['y']) if 'y' in data else None
        self.data_size = self.data_x.shape[0]
        self.feature_size = self.data_x.shape[1]
        self.replace = replace
        self.use_full_trainset = use_full_trainset
        # Clamp to the available feature count: asking for more columns than
        # exist would make the without-replacement draw in sample_feature
        # raise a ValueError.
        self.selected_feature_num = min(
            self.feature_size,
            max(1, int(self.feature_size * feature_frac)))
        # Set by sample_feature whenever a strict subset of columns is drawn.
        self.selected_feature_index = None

    def sample_data(self, sample_size, index=None):
        """
        Return a subset of the rows of the stored data.

        Arguments:
            sample_size (int): number of rows to draw; capped at the dataset
                size. Ignored when ``index`` is given or when the sampler
                uses the full train set.
            index: optional explicit row indices to select instead of
                drawing random ones

        :returns: row indices, selected x, selected y (None if the data has
            no 'y')
        """
        # use the entire dataset
        if self.use_full_trainset:
            return range(len(self.data_x)), self.data_x, self.data_y

        if index is None:
            sample_size = min(sample_size, self.data_size)
            index = np.random.choice(a=self.data_size,
                                     size=sample_size,
                                     replace=self.replace)

        # Both the caller-supplied and the freshly drawn indices are used to
        # gather rows in exactly the same way.
        sampled_x = self.data_x[index]
        sampled_y = self.data_y[index] if self.data_y is not None else None

        return index, sampled_x, sampled_y

    def sample_feature(self, x):
        """
        Return a random subset of the feature columns of ``x``.

        When every feature is selected, ``x`` is returned unchanged and
        ``selected_feature_index`` is left untouched. Otherwise
        ``selected_feature_num`` distinct column indices are drawn and
        remembered in ``selected_feature_index``.

        Arguments:
            x (ndarray): 2-D array whose columns are features

        :returns: selected column indices, ``x`` restricted to those columns
        """
        if self.selected_feature_num == self.feature_size:
            return range(x.shape[-1]), x

        feature_index = np.random.choice(a=self.feature_size,
                                         size=self.selected_feature_num,
                                         replace=False)
        self.selected_feature_index = feature_index

        return feature_index, x[:, feature_index]
|