# DDMR / Brain_study / split_dataset.py
# Last commit 6a4f823 by jpdefrutos: "Scripts for training on the IXI T1 MRI Dataset"
import os
import argparse
import random
import warnings
import math
from shutil import copyfile
from tqdm import tqdm
import concurrent.futures
import numpy as np
def copy_file(s_d):
    """Copy one file into a destination directory, keeping its base name.

    Takes a single tuple argument so it can be mapped directly over
    ``zip(sources, destinations)`` by an executor.

    :param s_d: Tuple ``(s, d)``: ``s`` is the path of the file to copy and
        ``d`` is the (already existing) directory to copy it into.
    :return: 1 if the copied file exists at its destination afterwards, else 0.
    """
    s, d = s_d
    dst_path = os.path.join(d, os.path.basename(s))
    copyfile(s, dst_path)
    # Check the copied file itself, not just the destination directory: the
    # directory exists before any copy starts, so checking it proves nothing.
    return int(os.path.exists(dst_path))
def _str_to_bool(value):
    """Convert a command-line string such as ``'True'`` or ``'no'`` to a bool.

    argparse's ``type=bool`` cannot be used for this: ``bool('False')`` is
    ``True``, so any non-empty string would enable the option.

    :raises argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in ('true', 't', 'yes', 'y', '1'):
        return True
    if normalized in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Expected a boolean value, got {!r}'.format(value))


def _split_counts(num_files, validation, test):
    """Return ``(num_train, num_validation, num_test)`` for ``num_files`` samples.

    Validation and test sizes are rounded down; the remainder goes to training.
    """
    # The tiny tolerance stops float error (e.g. 100 * 0.29 == 28.999999999999996)
    # from flooring one sample too few.
    num_validation = math.floor(num_files * validation + 1e-9)
    num_test = math.floor(num_files * test + 1e-9)
    return num_files - num_validation - num_test, num_validation, num_test


def main(argv=None):
    """Split the files of a directory into train/validation/test copies.

    Files in ``--dir`` whose names end with ``--format`` are copied into
    ``SPLIT_<dirname>/train_set``, ``validation_set`` and ``test_set``, created
    next to ``--dir``.

    :param argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--train', '-t', type=float, default=.70, help='Train percentage. Default: 0.70')
    parser.add_argument('--validation', '-v', type=float, default=0.15, help='Validation percentage. Default: 0.15')
    parser.add_argument('--test', '-s', type=float, default=0.15, help='Test percentage. Default: 0.15')
    parser.add_argument('-d', '--dir', type=str, required=True, help='Directory where the data is')
    parser.add_argument('-f', '--format', type=str, help='Format of the data files. Default: h5', default='h5')
    parser.add_argument('-r', '--random', type=_str_to_bool, help='Randomly split the dataset or not. Default: True', default=True)
    parser.add_argument('--seed', type=int, default=None, help='Seed for the random split. Default: None (not reproducible)')
    args = parser.parse_args(argv)

    # Validate with parser.error (exit code 2) rather than assert, which -O strips.
    fractions = (args.train, args.validation, args.test)
    if any(fraction < 0 or fraction > 1 for fraction in fractions):
        parser.error('Train, validation and test percentages must lie in [0, 1]')
    # Exact float equality would reject valid splits such as 0.7 + 0.2 + 0.1.
    if not math.isclose(sum(fractions), 1.0):
        parser.error('Train+Validation+Test != 1 (100%)')

    # abspath drops trailing separators, so os.path.split below never yields
    # an empty dataset name (which would produce an output folder named "SPLIT_").
    data_dir = os.path.abspath(args.dir)
    file_set = [os.path.join(data_dir, f) for f in os.listdir(data_dir)
                if f.endswith(args.format) and os.path.isfile(os.path.join(data_dir, f))]
    if args.random:
        random.Random(args.seed).shuffle(file_set)
    else:
        file_set.sort()

    num_files = len(file_set)
    num_train, num_validation, num_test = _split_counts(num_files, args.validation, args.test)

    dataset_root, dataset_name = os.path.split(data_dir)
    split_root = os.path.join(dataset_root, 'SPLIT_' + dataset_name)
    dst_train = os.path.join(split_root, 'train_set')
    dst_validation = os.path.join(split_root, 'validation_set')
    dst_test = os.path.join(split_root, 'test_set')

    print('OUTPUT INFORMATION\n=============')
    print('Train:\t\t{}'.format(num_train))
    print('Validation:\t{}'.format(num_validation))
    print('Test:\t\t{}'.format(num_test))
    print('Num. samples\t{}'.format(num_files))
    print('Path:\t\t', split_root)

    # One destination directory per file, in the same order as file_set.
    dest = [dst_train] * num_train + [dst_validation] * num_validation + [dst_test] * num_test
    for directory in (dst_train, dst_validation, dst_test):
        os.makedirs(directory, exist_ok=True)

    # Copying is I/O bound, so a thread pool gives the same parallelism as a
    # process pool without process start-up and argument pickling costs.
    with concurrent.futures.ThreadPoolExecutor(max_workers=10) as ex:
        results = list(tqdm(ex.map(copy_file, zip(file_set, dest)), desc='Copying files', total=num_files))

    num_copies = sum(results)
    if num_copies == num_files:
        print('Done successfully')
    else:
        warnings.warn('Missing files: {}'.format(num_files - num_copies))


if __name__ == '__main__':
    main()