File size: 5,433 Bytes
a0ae4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import argparse
import glob
import json
import logging
import os
import sys

import monai
from sklearn.model_selection import train_test_split


def produce_datalist_splits(datalist, splits: list = None, train_split: float = 0.80, valid_test_split: float = 0.50):
    """
    Partition ``datalist`` into training / validation / testing splits.

    Args:
        datalist: list of sample entries (e.g. ``{"image": ..., "label": ...}`` dicts).
        splits: split names to produce. If it contains ``"train"``, a training
            partition is carved out first; if it contains ``"valid"``, the
            remainder is further divided into validation and testing.
            Defaults to ``["test"]``, i.e. everything becomes testing data.
        train_split: fraction (0-1) of ``datalist`` used for training.
        valid_test_split: fraction (0-1) of the non-training samples used for
            validation (the rest becomes testing).

    Returns:
        dict with a subset of the keys ``"training"``, ``"validation"``, ``"testing"``.
    """
    if splits is None:
        splits = ["test"]
    if "train" not in splits:
        # No training partition requested: optionally split off validation only.
        if "valid" in splits:
            val_list, test_list = train_test_split(datalist, train_size=valid_test_split)
            return {"validation": val_list, "testing": test_list}
        return {"testing": datalist}
    train_list, other_list = train_test_split(datalist, train_size=train_split)
    if "valid" in splits:
        val_list, test_list = train_test_split(other_list, train_size=valid_test_split)
        return {"training": train_list, "validation": val_list, "testing": test_list}
    return {"training": train_list, "testing": other_list}


def keep_image_label_pairs_only(a_images, a_labels, i_folder, l_folder):
    """
    Keep only images that have a matching label (and vice versa), pairing by basename.

    Args:
        a_images: list of image file paths.
        a_labels: list of label file paths.
        i_folder: folder the surviving image basenames are re-joined onto.
        l_folder: folder the surviving label basenames are re-joined onto.

    Returns:
        (images, labels): two sorted, equally-filtered path lists rooted at
        ``i_folder`` / ``l_folder`` respectively.
    """
    # os.path.basename (not split("/")) so Windows separators are handled too.
    image_names = [os.path.basename(a) for a in a_images]
    label_names = [os.path.basename(a) for a in a_labels]
    # Set membership makes the filtering O(n) instead of O(n*m).
    label_name_set = set(label_names)
    a_images = sorted(os.path.join(i_folder, a) for a in image_names if a in label_name_set)
    # Keep only labels that have a scan.
    image_name_set = {os.path.basename(a) for a in a_images}
    a_labels = sorted(os.path.join(l_folder, a) for a in label_names if a in image_name_set)
    return a_images, a_labels


def parse_files(images_folder, labels_folder, file_extension_pattern):
    """Glob image and label files matching the pattern; return both lists sorted."""
    search_pattern = os.path.join(images_folder, file_extension_pattern)
    logging.info(f"parsing files at: {search_pattern}")
    found_images = glob.glob(search_pattern)
    found_labels = glob.glob(os.path.join(labels_folder, file_extension_pattern))
    return sorted(found_images), sorted(found_labels)


def get_datalist(args, images_folder, labels_folder):
    """
    Collect matched image/label pairs and return them as a MONAI-style datalist.

    Args:
        args: parsed CLI arguments; only ``args.file_extension`` is read here.
        images_folder: a folder path, or a list of folder paths, containing images.
        labels_folder: the corresponding label folder path(s); must be a list of
            the same length when ``images_folder`` is a list.

    Returns:
        list of ``{"image": <path>, "label": <path>}`` dicts.
    """
    file_extension_pattern = "*" + args.file_extension + "*"
    if isinstance(images_folder, list):
        # Multiple folder pairs: accumulate matched pairs from each.
        all_images = []
        all_labels = []
        for ifolder, lfolder in zip(images_folder, labels_folder):
            a_images, a_labels = parse_files(ifolder, lfolder, file_extension_pattern)
            a_images, a_labels = keep_image_label_pairs_only(a_images, a_labels, ifolder, lfolder)
            all_images += a_images
            all_labels += a_labels
    else:
        all_images, all_labels = parse_files(images_folder, labels_folder, file_extension_pattern)
        all_images, all_labels = keep_image_label_pairs_only(all_images, all_labels, images_folder, labels_folder)

    logging.info("Length of all_images: %s", len(all_images))
    logging.info("Length of all_labels: %s", len(all_labels))

    datalist = [{"image": image_name, "label": label_name} for image_name, label_name in zip(all_images, all_labels)]

    logging.info("datalist length is %s", len(datalist))
    return datalist


def main(args):
    """
    Split the dataset and write the resulting data list to a JSON file.
    """
    root_dir = args.path
    json_path = args.output
    # Fix the RNG seed so repeated runs reproduce the same splits.
    monai.utils.set_determinism(seed=123)
    records = get_datalist(args, root_dir, os.path.join(root_dir, args.labels_folder))
    records = produce_datalist_splits(records, args.splits, args.train_split, args.valid_test_split)
    with open(json_path, "w") as out_file:
        json.dump(records, out_file, ensure_ascii=True, indent=4)
    logging.info("datalist json file saved to: {}".format(json_path))


if __name__ == "__main__":
    logging.basicConfig(
        stream=sys.stdout,
        level=logging.DEBUG,
        format="[%(asctime)s.%(msecs)03d][%(levelname)5s](%(name)s) - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    parser = argparse.ArgumentParser(description="")
    parser.add_argument(
        "--path",
        type=str,
        default="/workspace/data/msd/Task07_Pancreas",
        help="root path of MSD Task07_Pancreas dataset.",
    )
    parser.add_argument(
        "--output", type=str, default="dataset_0.json", help="relative path of output datalist json file."
    )
    # These are fractions in (0, 1); type must be float — int would reject "0.8".
    parser.add_argument("--train_split", type=float, default=0.80, help="fraction of Training samples.")
    parser.add_argument("--valid_test_split", type=float, default=0.50, help="fraction of valid/test samples.")
    # nargs="+" yields a list of strings; type=list would split a word into characters.
    parser.add_argument(
        "--splits", type=str, nargs="+", default=["test"], help="splits to use for train, valid, and test."
    )
    parser.add_argument("--file_extension", type=str, default="nii", help="file extension of images and labels.")
    parser.add_argument("--labels_folder", type=str, default="labels/final", help="labels sub folder name")

    args = parser.parse_args()

    main(args)