File size: 3,154 Bytes
ea5f6fe
 
 
 
 
 
 
410a3d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea5f6fe
3cc12b1
410a3d6
3cc12b1
ea5f6fe
410a3d6
ea5f6fe
410a3d6
 
ea5f6fe
 
410a3d6
 
 
 
 
 
 
 
 
 
ea5f6fe
3cc12b1
410a3d6
 
ea5f6fe
410a3d6
ea5f6fe
e2c51cc
410a3d6
 
 
 
92b2916
410a3d6
6cb8ba2
410a3d6
 
 
6cb8ba2
410a3d6
 
e2c51cc
 
ea5f6fe
e2c51cc
 
 
 
92b2916
e2c51cc
ea5f6fe
e2c51cc
410a3d6
e2c51cc
ea5f6fe
e2c51cc
410a3d6
e2c51cc
3cc12b1
410a3d6
 
 
 
 
ea5f6fe
 
 
 
 
410a3d6
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
# Modified based on https://github.com/akarazniewicz/cocosplit/blob/master/cocosplit.py

import json
import argparse
import funcy
from sklearn.model_selection import train_test_split

def _ratio_in_unit_interval(text):
    """Parse *text* as a float and reject values outside the open interval (0, 1)."""
    try:
        ratio = float(text)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid float value: {text!r}") from None
    # The chained comparison is also False for NaN, so NaN is rejected too.
    if not 0.0 < ratio < 1.0:
        raise argparse.ArgumentTypeError(
            f"split ratio must be a number in (0, 1), got {text!r}"
        )
    return ratio


# Command-line interface of the script.
# The three path options are mandatory: each of them is handed straight to
# open(), so a missing one would otherwise surface as a TypeError on ``None``
# instead of a usage message.
parser = argparse.ArgumentParser(
    description="Splits COCO annotations file into training and test sets."
)
parser.add_argument(
    "--annotation-path",
    dest="annotation_path",
    metavar="coco_annotations",
    type=str,
    required=True,
    help="Path to COCO annotations file.",
)
parser.add_argument(
    "--train",
    type=str,
    required=True,
    help="Where to store COCO training annotations",
)
parser.add_argument(
    "--test",
    type=str,
    required=True,
    help="Where to store COCO test annotations",
)
parser.add_argument(
    "--split-ratio",
    dest="split_ratio",
    type=_ratio_in_unit_interval,
    required=True,
    help="A percentage of a split; a number in (0, 1)",
)
parser.add_argument(
    "--having-annotations",
    dest="having_annotations",
    action="store_true",
    help="Ignore all images without annotations. Keep only these with at least one annotation",
)


def save_coco(file, tagged_data):
    """Write *tagged_data* to the path *file* as JSON.

    The output is UTF-8 text, indented by two spaces, with object keys
    sorted. An existing file at that path is overwritten.
    """
    with open(file, mode="wt", encoding="UTF-8") as out_stream:
        json.dump(tagged_data, out_stream, sort_keys=True, indent=2)


def filter_annotations(annotations, images):
    """Return the annotations that belong to one of the given images.

    Args:
        annotations: Iterable of dicts, each with an ``"image_id"`` entry.
        images: Iterable of dicts, each with an ``"id"`` entry.

    Returns:
        A new list holding, in their original order, the annotations whose
        ``"image_id"`` equals the ``"id"`` of some image. Both sides are
        converted with ``int()`` before comparison.
    """
    # A set makes each membership test O(1); with a list the filter is
    # O(len(images) * len(annotations)).
    image_ids = {int(image["id"]) for image in images}
    return [
        annotation
        for annotation in annotations
        if int(annotation["image_id"]) in image_ids
    ]


def _split_images(images, split_ratio, random_state):
    """Split *images* into (train, test) lists; an empty input gives two empty lists."""
    if not images:
        # train_test_split refuses an empty input, so short-circuit here.
        return [], []
    train, test = train_test_split(
        images, train_size=split_ratio, random_state=random_state
    )
    return train, test


def main(
    annotation_path,
    split_ratio,
    having_annotations,
    train_save_path,
    test_save_path,
    random_state=None,
):
    """Split a COCO annotations file into a training and a test file.

    Images that have at least one annotation and images that have none are
    split separately with the same ratio, then merged, so both kinds are
    spread over the two outputs. Every other top-level entry of the input
    file is copied unchanged into both outputs.

    Args:
        annotation_path: Path of the COCO JSON file to read.
        split_ratio: Passed to ``train_test_split`` as ``train_size``.
        having_annotations: When true, images without any annotation are
            dropped from both outputs.
        train_save_path: Path the training split is written to.
        test_save_path: Path the test split is written to.
        random_state: Passed to ``train_test_split`` as ``random_state``.
    """
    with open(annotation_path, "rt", encoding="UTF-8") as annotation_file:
        coco = json.load(annotation_file)

    images = coco["images"]
    annotations = coco["annotations"]

    # Ids are coerced with int() on both sides (as filter_annotations does),
    # and kept in a set so each membership test is O(1).
    annotated_ids = {int(annotation["image_id"]) for annotation in annotations}

    # Images with annotations
    img_ann = [image for image in images if int(image["id"]) in annotated_ids]
    tr, ts = _split_images(img_ann, split_ratio, random_state)

    if not having_annotations:
        # Images without annotations are only split when they are kept;
        # they get their own split and are appended to the annotated ones.
        img_wo_ann = [
            image for image in images if int(image["id"]) not in annotated_ids
        ]
        tr_wo_ann, ts_wo_ann = _split_images(img_wo_ann, split_ratio, random_state)
        tr = [*tr, *tr_wo_ann]
        ts = [*ts, *ts_wo_ann]

    # Train Data
    save_coco(
        train_save_path,
        {**coco, "images": tr, "annotations": filter_annotations(annotations, tr)},
    )

    # Test Data
    save_coco(
        test_save_path,
        {**coco, "images": ts, "annotations": filter_annotations(annotations, ts)},
    )

    print(
        f"Saved {len(tr)} entries in {train_save_path} "
        f"and {len(ts)} in {test_save_path}"
    )


if __name__ == "__main__":
    # Script entry point: read the command-line options and run the split.
    # The random state is pinned to 24 for every command-line run.
    cli_args = parser.parse_args()
    main(
        annotation_path=cli_args.annotation_path,
        split_ratio=cli_args.split_ratio,
        having_annotations=cli_args.having_annotations,
        train_save_path=cli_args.train,
        test_save_path=cli_args.test,
        random_state=24,
    )