# Source: REPST/scripts/data_preparation/BLAST/merge_data.py
# (file-viewer metadata: 46 lines, 1.4 KiB, Python)

import os
from argparse import ArgumentParser
import numpy as np
from tqdm import tqdm
# Directory (relative to the current working directory) that holds the
# per-shard inputs `data_{i}_99.dat` / `shape_{i}_99.npy` and receives the
# merged outputs `data.dat` / `shape.npy`.
data_dir_path = 'datasets/BLAST/train'
def main(clean_cache=False):
    """Merge the 99 per-shard memmap files into one ``data.dat`` memmap.

    Reads ``shape_{i}_99.npy`` and ``data_{i}_99.dat`` for ``i`` in ``0..98``
    from ``data_dir_path``, stacks the shards along the first axis into a
    float32 memmap ``data.dat``, and saves the merged shape to ``shape.npy``.

    Args:
        clean_cache: If true, delete the per-shard ``.dat`` and ``.npy`` files
            once the merged file has been flushed to disk.

    Raises:
        ValueError: If a shard shape is not two-dimensional, or the shards do
            not all share the same second-axis length.
    """
    shard_ids = range(99)

    # Load each shard shape exactly once; it is needed both to size the
    # output and to map the shard itself.
    shard_shapes = []
    for i in shard_ids:
        shape_path = os.path.join(data_dir_path, f'shape_{i}_99.npy')
        shape = tuple(int(dim) for dim in np.load(shape_path))
        if len(shape) != 2:
            raise ValueError(
                f'{shape_path}: expected a 2-D shape, got {shape}')
        # Fail early with a clear message instead of a broadcast error (or a
        # silent broadcast) in the middle of the copy.
        if shard_shapes and shape[1] != shard_shapes[0][1]:
            raise ValueError(
                f'{shape_path}: second-axis length {shape[1]} does not match '
                f'{shard_shapes[0][1]} from the first shard')
        shard_shapes.append(shape)

    num_samples = sum(rows for rows, _ in shard_shapes)
    row_len = shard_shapes[0][1]

    merged_data = np.memmap(
        os.path.join(data_dir_path, 'data.dat'),
        mode='w+', dtype=np.float32, shape=(num_samples, row_len))

    print('Merging data...')
    current_index = 0
    for i in tqdm(shard_ids):
        shape = shard_shapes[i]
        if shape[0] == 0:
            # An empty shard contributes no rows, so there is nothing to map
            # or copy.
            continue
        data = np.memmap(
            os.path.join(data_dir_path, f'data_{i}_99.dat'),
            mode='r', dtype=np.float32, shape=shape)
        merged_data[current_index:current_index + shape[0]] = data
        current_index += shape[0]
        # Drop the read-only mapping immediately so no shard file is still
        # mapped when the cleanup step below tries to delete it.
        del data

    # Make sure every copied row is on disk before the shape file is written
    # and, more importantly, before any source shard is removed.
    merged_data.flush()
    np.save(os.path.join(data_dir_path, 'shape.npy'), merged_data.shape)
    print('Data merged successfully.')

    if clean_cache:
        print('Cleaning cache...')
        for i in tqdm(shard_ids):
            os.remove(os.path.join(data_dir_path, f'data_{i}_99.dat'))
            os.remove(os.path.join(data_dir_path, f'shape_{i}_99.npy'))
        print('Cache cleaned.')
def _str_to_bool(value):
    """Convert a command-line token such as ``'false'`` or ``'1'`` to a bool.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean
            spelling; argparse turns this into a usage error.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def parse_args(argv=None):
    """Parse the command-line options of the merge script.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default),
            argparse reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace; ``clean_cache`` is a real ``bool``. It defaults
        to ``True``, a bare ``--clean_cache`` also means ``True``, and
        ``--clean_cache false`` (or ``0``/``no``) disables the cleanup.
    """
    parser = ArgumentParser(description='Merge data files into a single memmap file.')
    # Without an explicit type, any supplied value (even "False") would be a
    # non-empty string and therefore truthy, so parse it into a bool here.
    parser.add_argument(
        '--clean_cache',
        type=_str_to_bool,
        nargs='?',
        const=True,
        default=True,
        help='Clean cache after merging.')
    return parser.parse_args(argv)
if __name__ == '__main__':
    # Script entry point: parse the CLI options and forward the cache-cleanup
    # flag to the merge routine.
    cli_options = parse_args()
    main(clean_cache=cli_options.clean_cache)