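"""Split a directory of data files into train, validation and test subsets.

Files matching the requested extension are copied (or moved, when requested)
into SPLIT_<dataset_name>/{train_set, validation_set, test_set}, created next
to the original directory, according to the requested percentages.
"""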
import os
import argparse
import random
import warnings
import math
from shutil import copyfile, move
from tqdm import tqdm
import concurrent.futures
import numpy as np


def copy_file_fnc(s_d):
    """Copy one (source_path, destination_dir) pair; return 1 if the copy exists, else 0."""
    s, d = s_d
    file_name = os.path.split(s)[-1]
    copyfile(s, os.path.join(d, file_name))
    # Check the copied file itself, not the destination directory (which always exists).
    return int(os.path.exists(os.path.join(d, file_name)))


def move_file_fnc(s_d):
    """Move one (source_path, destination_dir) pair; return 1 if the moved file exists, else 0."""
    s, d = s_d
    file_name = os.path.split(s)[-1]
    move(s, os.path.join(d, file_name))
    # Check the moved file itself, not the destination directory (which always exists).
    return int(os.path.exists(os.path.join(d, file_name)))


def split(train_perc: float = 0.7,
          validation_perc: float = 0.15,
          test_perc: float = 0.15,
          data_dir: str = '',
          file_format: str = 'h5',
          random_split: bool = True,
          move_files: bool = False):
    """Split the files in data_dir into train/validation/test subsets."""
    # math.isclose avoids spurious failures from floating-point rounding
    # (e.g. 0.7 + 0.15 + 0.15 != 1.0 with exact equality).
    assert math.isclose(train_perc + validation_perc + test_perc, 1.0), \
        'Train + Validation + Test != 1 (100%)'
    assert train_perc > 0 and test_perc > 0, 'Train and test percentages must be greater than zero'

    file_set = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(file_format)]
    if random_split:
        random.shuffle(file_set)
    else:
        file_set.sort()

    num_files = len(file_set)
    num_validation = math.floor(num_files * validation_perc)
    num_test = math.floor(num_files * test_perc)
    num_train = num_files - num_test - num_validation

    dataset_root, dataset_name = os.path.split(data_dir)
    dst_train = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'train_set')
    dst_validation = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'validation_set')
    dst_test = os.path.join(dataset_root, 'SPLIT_' + dataset_name, 'test_set')

    print('OUTPUT INFORMATION\n==================')
    print('Train:\t\t{}'.format(num_train))
    print('Validation:\t{}'.format(num_validation))
    print('Test:\t\t{}'.format(num_test))
    print('Num. samples:\t{}'.format(num_files))
    print('Path:\t\t', os.path.join(dataset_root, 'SPLIT_' + dataset_name))

    dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
    os.makedirs(dst_train, exist_ok=True)
    os.makedirs(dst_validation, exist_ok=True)
    os.makedirs(dst_test, exist_ok=True)

    operation = move_file_fnc if move_files else copy_file_fnc
    desc = 'Moving files' if move_files else 'Copying files'
    # Copy/move the files in parallel; each worker returns 1 on success, 0 on failure.
    with concurrent.futures.ProcessPoolExecutor(max_workers=10) as ex:
        results = list(tqdm(ex.map(operation, zip(file_set, dest)), desc=desc, total=num_files))

    num_copies = np.sum(results)
    if num_copies == num_files:
        print('Done successfully')
    else:
        warnings.warn('Missing files: {}'.format(num_files - num_copies))


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', '-t', type=float, default=0.70, help='Train percentage. Default: 0.70')
    parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
    parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
    parser.add_argument('-d', '--dir', type=str, help='Directory where the data is')
    parser.add_argument('-f', '--format', type=str, default='h5', help='Format of the data files. Default: h5')
    # BooleanOptionalAction (Python 3.9+) provides both --random and --no-random,
    # so the flag can actually be turned off from the command line.
    parser.add_argument('-r', '--random', action=argparse.BooleanOptionalAction, default=True,
                        help='Randomly split the dataset or not. Default: True')
    parser.add_argument('-m', '--movefiles', action='store_true', default=False,
                        help='Move files instead of copying them. Default: False')
    args = parser.parse_args()
    split(args.train, args.validation, args.test, args.dir, args.format, args.random, args.movefiles)
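
# Example invocation (hypothetical script name and paths, adjust to your setup):
#   python split_dataset.py --dir /data/MyDataset --format h5 -t 0.7 -v 0.15 -s 0.15
# This would create /data/SPLIT_MyDataset/{train_set, validation_set, test_set}
# and copy the matching .h5 files into them according to the requested percentages.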