mckabue commited on
Commit
43a2b9d
·
verified ·
1 Parent(s): ad8a3c2

RE_UPLOAD-REBUILD-RESTART

Browse files
model/layout-model-training/utils/cocosplit.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Modified based on https://github.com/akarazniewicz/cocosplit/blob/master/cocosplit.py
2
+
3
+ import json
4
+ import argparse
5
+ import funcy
6
+ from sklearn.model_selection import train_test_split
7
+
8
+ parser = argparse.ArgumentParser(
9
+ description="Splits COCO annotations file into training and test sets."
10
+ )
11
+ parser.add_argument(
12
+ "--annotation-path",
13
+ metavar="coco_annotations",
14
+ type=str,
15
+ help="Path to COCO annotations file.",
16
+ )
17
+ parser.add_argument(
18
+ "--train", type=str, help="Where to store COCO training annotations"
19
+ )
20
+ parser.add_argument("--test", type=str, help="Where to store COCO test annotations")
21
+ parser.add_argument(
22
+ "--split-ratio",
23
+ dest="split_ratio",
24
+ type=float,
25
+ required=True,
26
+ help="A percentage of a split; a number in (0, 1)",
27
+ )
28
+ parser.add_argument(
29
+ "--having-annotations",
30
+ dest="having_annotations",
31
+ action="store_true",
32
+ help="Ignore all images without annotations. Keep only these with at least one annotation",
33
+ )
34
+
35
+
36
+ def save_coco(file, tagged_data):
37
+ with open(file, "wt", encoding="UTF-8") as coco:
38
+ json.dump(tagged_data, coco, indent=2, sort_keys=True)
39
+
40
+
41
+ def filter_annotations(annotations, images):
42
+ image_ids = funcy.lmap(lambda i: int(i["id"]), images)
43
+ return funcy.lfilter(lambda a: int(a["image_id"]) in image_ids, annotations)
44
+
45
+
46
+ def main(
47
+ annotation_path,
48
+ split_ratio,
49
+ having_annotations,
50
+ train_save_path,
51
+ test_save_path,
52
+ random_state=None,
53
+ ):
54
+
55
+ with open(annotation_path, "rt", encoding="UTF-8") as annotations:
56
+ coco = json.load(annotations)
57
+
58
+ images = coco["images"]
59
+ annotations = coco["annotations"]
60
+
61
+ ids_with_annotations = funcy.lmap(lambda a: int(a["image_id"]), annotations)
62
+
63
+ # Images with annotations
64
+ img_ann = funcy.lremove(lambda i: i["id"] not in ids_with_annotations, images)
65
+ tr_ann, ts_ann = train_test_split(
66
+ img_ann, train_size=split_ratio, random_state=random_state
67
+ )
68
+
69
+ img_wo_ann = funcy.lremove(lambda i: i["id"] in ids_with_annotations, images)
70
+ if len(img_wo_ann) > 0:
71
+ tr_wo_ann, ts_wo_ann = train_test_split(
72
+ img_wo_ann, train_size=split_ratio, random_state=random_state
73
+ )
74
+ else:
75
+ tr_wo_ann, ts_wo_ann = [], [] # Images without annotations
76
+
77
+ if having_annotations:
78
+ tr, ts = tr_ann, ts_ann
79
+
80
+ else:
81
+ # Merging the 2 image lists (i.e. with and without annotation)
82
+ tr_ann.extend(tr_wo_ann)
83
+ ts_ann.extend(ts_wo_ann)
84
+
85
+ tr, ts = tr_ann, ts_ann
86
+
87
+ # Train Data
88
+ coco.update({"images": tr, "annotations": filter_annotations(annotations, tr)})
89
+ save_coco(train_save_path, coco)
90
+
91
+ # Test Data
92
+ coco.update({"images": ts, "annotations": filter_annotations(annotations, ts)})
93
+ save_coco(test_save_path, coco)
94
+
95
+ print(
96
+ "Saved {} entries in {} and {} in {}".format(
97
+ len(tr), train_save_path, len(ts), test_save_path
98
+ )
99
+ )
100
+
101
+
102
+ if __name__ == "__main__":
103
+ args = parser.parse_args()
104
+
105
+ main(
106
+ args.annotation_path,
107
+ args.split_ratio,
108
+ args.having_annotations,
109
+ args.train,
110
+ args.test,
111
+ random_state=24,
112
+ )