Shannon Shen committed on
Commit
410a3d6
·
1 Parent(s): 6cb8ba2

black formatting

Browse files
Files changed (1) hide show
  1. utils/cocosplit.py +69 -45
utils/cocosplit.py CHANGED
@@ -5,52 +5,75 @@ import argparse
5
  import funcy
6
  from sklearn.model_selection import train_test_split
7
 
8
- parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.')
9
- parser.add_argument('--annotation-path', metavar='coco_annotations', type=str,
10
- help='Path to COCO annotations file.')
11
- parser.add_argument('--train', type=str, help='Where to store COCO training annotations')
12
- parser.add_argument('--test', type=str, help='Where to store COCO test annotations')
13
- parser.add_argument('--split-ratio', dest='split_ratio', type=float, required=True,
14
- help="A percentage of a split; a number in (0, 1)")
15
- parser.add_argument('--having-annotations', dest='having_annotations', action='store_true',
16
- help='Ignore all images without annotations. Keep only these with at least one annotation')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
  def save_coco(file, tagged_data):
19
- with open(file, 'wt', encoding='UTF-8') as coco:
20
  json.dump(tagged_data, coco, indent=2, sort_keys=True)
21
 
 
22
  def filter_annotations(annotations, images):
23
- image_ids = funcy.lmap(lambda i: int(i['id']), images)
24
- return funcy.lfilter(lambda a: int(a['image_id']) in image_ids, annotations)
25
 
26
- def main(annotation_path,
27
- split_ratio,
28
- having_annotations,
29
- train_save_path,
30
- test_save_path,
31
- random_state=None):
32
 
33
- with open(annotation_path, 'rt', encoding='UTF-8') as annotations:
 
 
 
 
 
 
 
 
 
34
  coco = json.load(annotations)
35
 
36
- images = coco['images']
37
- annotations = coco['annotations']
38
 
39
- ids_with_annotations = funcy.lmap(lambda a: int(a['image_id']), annotations)
40
 
41
  # Images with annotations
42
- img_ann = funcy.lremove(lambda i: i['id'] not in ids_with_annotations, images)
43
- tr_ann, ts_ann = train_test_split(img_ann, train_size=split_ratio,
44
- random_state=random_state)
 
45
 
46
- img_wo_ann = funcy.lremove(lambda i: i['id'] in ids_with_annotations, images)
47
  if len(img_wo_ann) > 0:
48
- tr_wo_ann, ts_wo_ann = train_test_split(img_wo_ann, train_size=split_ratio,
49
- random_state=random_state)
 
50
  else:
51
- tr_wo_ann, ts_wo_ann = [], []# Images without annotations
52
-
53
-
54
  if having_annotations:
55
  tr, ts = tr_ann, ts_ann
56
 
@@ -62,27 +85,28 @@ def main(annotation_path,
62
  tr, ts = tr_ann, ts_ann
63
 
64
  # Train Data
65
- coco.update({'images': tr,
66
- 'annotations': filter_annotations(annotations, tr)})
67
  save_coco(train_save_path, coco)
68
 
69
  # Test Data
70
- coco.update({'images': ts,
71
- 'annotations': filter_annotations(annotations, ts)})
72
  save_coco(test_save_path, coco)
73
 
74
- print("Saved {} entries in {} and {} in {}".format(len(tr),
75
- train_save_path,
76
- len(ts),
77
- test_save_path))
 
78
 
79
 
80
  if __name__ == "__main__":
81
  args = parser.parse_args()
82
 
83
- main(args.annotation_path,
84
- args.split_ratio,
85
- args.having_annotations,
86
- args.train,
87
- args.test,
88
- random_state=24)
 
 
 
5
  import funcy
6
  from sklearn.model_selection import train_test_split
7
 
8
+ parser = argparse.ArgumentParser(
9
+ description="Splits COCO annotations file into training and test sets."
10
+ )
11
+ parser.add_argument(
12
+ "--annotation-path",
13
+ metavar="coco_annotations",
14
+ type=str,
15
+ help="Path to COCO annotations file.",
16
+ )
17
+ parser.add_argument(
18
+ "--train", type=str, help="Where to store COCO training annotations"
19
+ )
20
+ parser.add_argument("--test", type=str, help="Where to store COCO test annotations")
21
+ parser.add_argument(
22
+ "--split-ratio",
23
+ dest="split_ratio",
24
+ type=float,
25
+ required=True,
26
+ help="A percentage of a split; a number in (0, 1)",
27
+ )
28
+ parser.add_argument(
29
+ "--having-annotations",
30
+ dest="having_annotations",
31
+ action="store_true",
32
+ help="Ignore all images without annotations. Keep only these with at least one annotation",
33
+ )
34
+
35
 
36
  def save_coco(file, tagged_data):
37
+ with open(file, "wt", encoding="UTF-8") as coco:
38
  json.dump(tagged_data, coco, indent=2, sort_keys=True)
39
 
40
+
41
  def filter_annotations(annotations, images):
42
+ image_ids = funcy.lmap(lambda i: int(i["id"]), images)
43
+ return funcy.lfilter(lambda a: int(a["image_id"]) in image_ids, annotations)
44
 
 
 
 
 
 
 
45
 
46
+ def main(
47
+ annotation_path,
48
+ split_ratio,
49
+ having_annotations,
50
+ train_save_path,
51
+ test_save_path,
52
+ random_state=None,
53
+ ):
54
+
55
+ with open(annotation_path, "rt", encoding="UTF-8") as annotations:
56
  coco = json.load(annotations)
57
 
58
+ images = coco["images"]
59
+ annotations = coco["annotations"]
60
 
61
+ ids_with_annotations = funcy.lmap(lambda a: int(a["image_id"]), annotations)
62
 
63
  # Images with annotations
64
+ img_ann = funcy.lremove(lambda i: i["id"] not in ids_with_annotations, images)
65
+ tr_ann, ts_ann = train_test_split(
66
+ img_ann, train_size=split_ratio, random_state=random_state
67
+ )
68
 
69
+ img_wo_ann = funcy.lremove(lambda i: i["id"] in ids_with_annotations, images)
70
  if len(img_wo_ann) > 0:
71
+ tr_wo_ann, ts_wo_ann = train_test_split(
72
+ img_wo_ann, train_size=split_ratio, random_state=random_state
73
+ )
74
  else:
75
+ tr_wo_ann, ts_wo_ann = [], [] # Images without annotations
76
+
 
77
  if having_annotations:
78
  tr, ts = tr_ann, ts_ann
79
 
 
85
  tr, ts = tr_ann, ts_ann
86
 
87
  # Train Data
88
+ coco.update({"images": tr, "annotations": filter_annotations(annotations, tr)})
 
89
  save_coco(train_save_path, coco)
90
 
91
  # Test Data
92
+ coco.update({"images": ts, "annotations": filter_annotations(annotations, ts)})
 
93
  save_coco(test_save_path, coco)
94
 
95
+ print(
96
+ "Saved {} entries in {} and {} in {}".format(
97
+ len(tr), train_save_path, len(ts), test_save_path
98
+ )
99
+ )
100
 
101
 
102
  if __name__ == "__main__":
103
  args = parser.parse_args()
104
 
105
+ main(
106
+ args.annotation_path,
107
+ args.split_ratio,
108
+ args.having_annotations,
109
+ args.train,
110
+ args.test,
111
+ random_state=24,
112
+ )