File size: 3,677 Bytes
6a4f823
 
 
 
 
 
286a978
6a4f823
 
 
 
 
286a978
6a4f823
 
 
 
 
 
286a978
 
 
 
 
6a4f823
 
286a978
 
 
 
 
 
 
 
 
6a4f823
286a978
 
6a4f823
 
286a978
 
6a4f823
 
286a978
 
 
 
6a4f823
 
 
 
 
 
286a978
6a4f823
 
 
 
 
 
 
 
286a978
 
6a4f823
286a978
6a4f823
 
 
 
 
 
 
286a978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os
import argparse
import random
import warnings

import math
from shutil import copyfile, move
from tqdm import tqdm
import concurrent.futures
import numpy as np


def copy_file_fnc(s_d):
    """Copy one file into a destination directory, keeping its file name.

    Args:
        s_d: ``(source_path, destination_dir)`` pair. It is packed into a
            single argument so the function can be passed to ``Executor.map``
            over ``zip(sources, destinations)``.

    Returns:
        1 if the copied file exists in ``destination_dir`` afterwards, else 0.
    """
    s, d = s_d
    target = os.path.join(d, os.path.basename(s))
    copyfile(s, target)
    # Verify the copied file itself; checking only the destination directory
    # (as before) reported success regardless of what happened to the file.
    return int(os.path.exists(target))


def move_file_fnc(s_d):
    """Move one file into a destination directory, keeping its file name.

    Args:
        s_d: ``(source_path, destination_dir)`` pair. It is packed into a
            single argument so the function can be passed to ``Executor.map``
            over ``zip(sources, destinations)``.

    Returns:
        1 if the moved file exists in ``destination_dir`` afterwards, else 0.
    """
    s, d = s_d
    target = os.path.join(d, os.path.basename(s))
    move(s, target)
    # Verify the moved file itself, not merely the destination directory.
    return int(os.path.exists(target))


def split(train_perc: float=0.7,
          validation_perc: float=0.15,
          test_perc: float=0.15,
          data_dir: str='',
          file_format: str='h5',
          random_split: bool=True,
          move_files: bool=False):
    """Split the files of ``data_dir`` into train/validation/test folders.

    A sibling directory ``SPLIT_<name of data_dir>`` is created next to
    ``data_dir``. It holds ``train_set``, ``validation_set`` and ``test_set``
    subfolders, and every regular file whose name ends with ``file_format``
    is copied (or moved) into exactly one of them. The validation and test
    counts are rounded down; the remainder goes to the train set.

    Args:
        train_perc: Fraction of files for the train set (must be > 0).
        validation_perc: Fraction of files for the validation set (>= 0).
        test_perc: Fraction of files for the test set (must be > 0).
        data_dir: Directory holding the data files.
        file_format: Filename suffix used to select files (e.g. ``'h5'``).
        random_split: Shuffle the files before splitting; otherwise sort
            them by path so the split is deterministic.
        move_files: Move the files instead of copying them.

    Raises:
        AssertionError: If the fractions are out of range or do not sum to 1
            (checked with ``assert``, so skipped under ``python -O``).
    """
    # Tolerant comparison: e.g. 0.7 + 0.2 + 0.1 evaluates to 0.9999999999999999.
    assert math.isclose(train_perc + validation_perc + test_perc, 1.0), \
        'Train+Validation+Test != 1 (100%)'
    assert train_perc > 0 and test_perc > 0, 'Train and test percentages must be greater than zero'
    # A negative validation share would make num_train exceed num_files.
    assert validation_perc >= 0, 'Validation percentage must not be negative'

    file_set = [os.path.join(data_dir, f) for f in os.listdir(data_dir) if f.endswith(file_format)]
    # Skip sub-directories that happen to match the suffix; copyfile cannot copy them.
    file_set = [path for path in file_set if os.path.isfile(path)]
    if random_split:
        random.shuffle(file_set)
    else:
        file_set.sort()

    num_files = len(file_set)
    num_validation = math.floor(num_files * validation_perc)
    num_test = math.floor(num_files * test_perc)
    num_train = num_files - num_test - num_validation

    # normpath drops a trailing separator ('data/' -> 'data'). Without it,
    # os.path.split returns an empty name and the output would be written
    # inside data_dir itself as 'SPLIT_'.
    dataset_root, dataset_name = os.path.split(os.path.normpath(data_dir))
    split_root = os.path.join(dataset_root, 'SPLIT_' + dataset_name)
    dst_train = os.path.join(split_root, 'train_set')
    dst_validation = os.path.join(split_root, 'validation_set')
    dst_test = os.path.join(split_root, 'test_set')

    print('OUTPUT INFORMATION\n=============')
    print('Train:\t\t{}'.format(num_train))
    print('Validation:\t{}'.format(num_validation))
    print('Test:\t\t{}'.format(num_test))
    print('Num. samples\t{}'.format(num_files))
    print('Path:\t\t', split_root)

    # One destination per file, aligned with the (shuffled or sorted) file_set.
    dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test

    for directory in (dst_train, dst_validation, dst_test):
        os.makedirs(directory, exist_ok=True)

    operation = move_file_fnc if move_files else copy_file_fnc
    desc = 'Moving files' if move_files else 'Copying files'
    # Copying/moving files is I/O-bound, so threads give the same parallelism
    # as processes without process start-up or pickling overhead.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as ex:
        results = list(tqdm(ex.map(operation, zip(file_set, dest)), desc=desc, total=num_files))

    num_copies = sum(results)
    if num_copies == num_files:
        print('Done successfully')
    else:
        warnings.warn('Missing files: {}'.format(num_files - num_copies))

if __name__ == '__main__':
    # Command-line entry point: parse the split fractions and options, then
    # run split() on the given directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
    parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
    parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
    # required: without it args.dir is None and split() fails with a TypeError.
    parser.add_argument('-d', '--dir', type=str, required=True, help='Directory where the data is')
    parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
    # '--random' alone can never turn shuffling off (store_true with
    # default=True), so '--no-random' writes False into the same dest.
    parser.add_argument('-r', '--random', help='Randomly split the dataset. Default: True', action='store_true', default=True)
    parser.add_argument('--no-random', dest='random', action='store_false',
                        help='Sort the files by name instead of shuffling them')
    parser.add_argument('-m', '--movefiles', help='Move files. Otherwise copy. Default: False', action='store_true', default=False)

    args = parser.parse_args()

    split(train_perc=args.train,
          validation_perc=args.validation,
          test_perc=args.test,
          data_dir=args.dir,
          file_format=args.format,
          random_split=args.random,
          move_files=args.movefiles)