Shannon Shen commited on
Commit
6cb8ba2
·
1 Parent(s): 467f054

a better solution

Browse files
Files changed (1) hide show
  1. utils/cocosplit.py +9 -6
utils/cocosplit.py CHANGED
@@ -6,7 +6,7 @@ import funcy
6
  from sklearn.model_selection import train_test_split
7
 
8
  parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.')
9
- parser.add_argument('--annotation_path', metavar='coco_annotations', type=str,
10
  help='Path to COCO annotations file.')
11
  parser.add_argument('--train', type=str, help='Where to store COCO training annotations')
12
  parser.add_argument('--test', type=str, help='Where to store COCO test annotations')
@@ -43,15 +43,18 @@ def main(annotation_path,
43
  tr_ann, ts_ann = train_test_split(img_ann, train_size=split_ratio,
44
  random_state=random_state)
45
 
 
 
 
 
 
 
 
 
46
  if having_annotations:
47
  tr, ts = tr_ann, ts_ann
48
 
49
  else:
50
- # Images without annotations
51
- img_wo_ann = funcy.lremove(lambda i: i['id'] in ids_with_annotations, images)
52
- tr_wo_ann, ts_wo_ann = train_test_split(img_wo_ann, train_size=split_ratio,
53
- random_state=random_state)
54
-
55
  # Merging the 2 image lists (i.e. with and without annotation)
56
  tr_ann.extend(tr_wo_ann)
57
  ts_ann.extend(ts_wo_ann)
 
6
  from sklearn.model_selection import train_test_split
7
 
8
  parser = argparse.ArgumentParser(description='Splits COCO annotations file into training and test sets.')
9
+ parser.add_argument('--annotation-path', metavar='coco_annotations', type=str,
10
  help='Path to COCO annotations file.')
11
  parser.add_argument('--train', type=str, help='Where to store COCO training annotations')
12
  parser.add_argument('--test', type=str, help='Where to store COCO test annotations')
 
43
  tr_ann, ts_ann = train_test_split(img_ann, train_size=split_ratio,
44
  random_state=random_state)
45
 
46
+ img_wo_ann = funcy.lremove(lambda i: i['id'] in ids_with_annotations, images)
47
+ if len(img_wo_ann) > 0:
48
+ tr_wo_ann, ts_wo_ann = train_test_split(img_wo_ann, train_size=split_ratio,
49
+ random_state=random_state)
50
+ else:
51
+ tr_wo_ann, ts_wo_ann = [], []# Images without annotations
52
+
53
+
54
  if having_annotations:
55
  tr, ts = tr_ann, ts_ann
56
 
57
  else:
 
 
 
 
 
58
  # Merging the 2 image lists (i.e. with and without annotation)
59
  tr_ann.extend(tr_wo_ann)
60
  ts_ann.extend(ts_wo_ann)