Shannon Shen committed
Commit c4d6291
2 Parents: 95245f5 29d3845

Merge pull request #9 from TITC/patch-1

Files changed (1):
  1. tools/train_net.py +58 -29
tools/train_net.py CHANGED
@@ -14,7 +14,13 @@ from detectron2.data import DatasetMapper, build_detection_train_loader
 
 from detectron2.data.datasets import register_coco_instances
 
-from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
+from detectron2.engine import (
+    DefaultTrainer,
+    default_argument_parser,
+    default_setup,
+    hooks,
+    launch,
+)
 from detectron2.evaluation import (
     COCOEvaluator,
     verify_results,
@@ -25,12 +31,14 @@ import pandas as pd
 
 def get_augs(cfg):
     """Add all the desired augmentations here. A list of availble augmentations
-    can be found here:
+    can be found here:
     https://detectron2.readthedocs.io/en/latest/modules/data_transforms.html
     """
     augs = [
         T.ResizeShortestEdge(
-            cfg.INPUT.MIN_SIZE_TRAIN, cfg.INPUT.MAX_SIZE_TRAIN, cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING
+            cfg.INPUT.MIN_SIZE_TRAIN,
+            cfg.INPUT.MAX_SIZE_TRAIN,
+            cfg.INPUT.MIN_SIZE_TRAIN_SAMPLING,
         )
     ]
     if cfg.INPUT.CROP.ENABLED:
@@ -42,13 +50,13 @@ def get_augs(cfg):
                 cfg.MODEL.SEM_SEG_HEAD.IGNORE_VALUE,
             )
         )
-    horizontal_flip: bool = (cfg.INPUT.RANDOM_FLIP == 'horizontal')
-    augs.append(T.RandomFlip(horizontal=horizontal_flip,
-                             vertical=not horizontal_flip))
+    horizontal_flip: bool = cfg.INPUT.RANDOM_FLIP == "horizontal"
+    augs.append(T.RandomFlip(horizontal=horizontal_flip, vertical=not horizontal_flip))
     # Rotate the image between -90 to 0 degrees clockwise around the centre
     augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
     return augs
 
+
 class Trainer(DefaultTrainer):
     """
     We use the "DefaultTrainer" which contains pre-defined default logic for
@@ -56,7 +64,7 @@ class Trainer(DefaultTrainer):
     are working on a new research project. In that case you can use the cleaner
     "SimpleTrainer", or write your own training loop. You can use
     "tools/plain_train_net.py" as an example.
-
+
     Adapted from:
     https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
     """
@@ -102,19 +110,21 @@ class Trainer(DefaultTrainer):
             for name in cfg.DATASETS.TEST
         ]
         res = cls.test(cfg, model, evaluators)
-        pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
+        pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, "eval.csv"))
         return res
 
+
 def setup(args):
     """
     Create configs and perform basic setups.
     """
     cfg = get_cfg()
+
     if args.config_file != "":
         cfg.merge_from_file(args.config_file)
     cfg.merge_from_list(args.opts)
 
-    with open(args.json_annotation_train, 'r') as fp:
+    with open(args.json_annotation_train, "r") as fp:
         anno_file = json.load(fp)
 
     cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(anno_file["categories"])
@@ -122,13 +132,26 @@ def setup(args):
 
     cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
     cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
-
     cfg.freeze()
     default_setup(cfg, args)
     return cfg
 
 
 def main(args):
+    # Register Datasets
+    register_coco_instances(
+        f"{args.dataset_name}-train",
+        {},
+        args.json_annotation_train,
+        args.image_path_train,
+    )
+
+    register_coco_instances(
+        f"{args.dataset_name}-val",
+        {},
+        args.json_annotation_val,
+        args.image_path_val
+    )
     cfg = setup(args)
 
     if args.eval_only:
@@ -137,14 +160,14 @@ def main(args):
             cfg.MODEL.WEIGHTS, resume=args.resume
         )
         res = Trainer.test(cfg, model)
-
+
         if cfg.TEST.AUG.ENABLED:
             res.update(Trainer.test_with_TTA(cfg, model))
         if comm.is_main_process():
             verify_results(cfg, res)
 
         # Save the evaluation results
-        pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
+        pd.DataFrame(res).to_csv(f"{cfg.OUTPUT_DIR}/eval.csv")
         return res
 
     # Ensure that the Output directory exists
@@ -158,7 +181,7 @@ def main(args):
     trainer = Trainer(cfg)
     trainer.resume_or_load(resume=args.resume)
     trainer.register_hooks(
-        [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
+        [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
     )
     if cfg.TEST.AUG.ENABLED:
         trainer.register_hooks(
@@ -171,24 +194,30 @@ if __name__ == "__main__":
     parser = default_argument_parser()
 
     # Extra Configurations for dataset names and paths
-    parser.add_argument("--dataset_name", default="", help="The Dataset Name")
-    parser.add_argument("--json_annotation_train", default="", metavar="FILE", help="The path to the training set JSON annotation")
-    parser.add_argument("--image_path_train", default="", metavar="FILE", help="The path to the training set image folder")
-    parser.add_argument("--json_annotation_val", default="", metavar="FILE", help="The path to the validation set JSON annotation")
-    parser.add_argument("--image_path_val", default="", metavar="FILE", help="The path to the validation set image folder")
-
+    parser.add_argument(
+        "--dataset_name",
+        help="The Dataset Name")
+    parser.add_argument(
+        "--json_annotation_train",
+        help="The path to the training set JSON annotation",
+    )
+    parser.add_argument(
+        "--image_path_train",
+        help="The path to the training set image folder",
+    )
+    parser.add_argument(
+        "--json_annotation_val",
+        help="The path to the validation set JSON annotation",
+    )
+    parser.add_argument(
+        "--image_path_val",
+        help="The path to the validation set image folder",
+    )
     args = parser.parse_args()
     print("Command Line Args:", args)
 
-    # Register Datasets
-    dataset_name = args.dataset_name
-    register_coco_instances(f"{dataset_name}-train", {},
-                            args.json_annotation_train,
-                            args.image_path_train)
-
-    register_coco_instances(f"{dataset_name}-val", {},
-                            args.json_annotation_val,
-                            args.image_path_val)
+    # Dataset Registration is moved to the main function to support multi-gpu training
+    # See ref https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517
 
     launch(
         main,
@@ -197,4 +226,4 @@ if __name__ == "__main__":
         machine_rank=args.machine_rank,
         dist_url=args.dist_url,
         args=(args,),
-    )
+    )
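
Note on the substantive change in this patch: register_coco_instances is now called inside main() rather than inside the if __name__ == "__main__": block. detectron2's launch() starts one worker process per GPU, the DatasetCatalog is per-process state, and only the function handed to launch() runs in every worker, so datasets registered in the parent's __main__ block were never visible to the workers (this is the point of the detectron2 issue linked in the new comment). The following is a minimal, self-contained sketch of that pattern, not part of the commit; worker_entry and the assertion are illustrative stand-ins for the real main().

# Minimal sketch (not part of this commit) of why registration lives in the
# worker entry point: each process spawned by launch() must register the
# dataset itself before DefaultTrainer builds its data loaders.
from detectron2.data import DatasetCatalog
from detectron2.data.datasets import register_coco_instances
from detectron2.engine import default_argument_parser, launch


def worker_entry(args):
    # Runs once in every worker process, mirroring the patched main().
    register_coco_instances(
        f"{args.dataset_name}-train",
        {},
        args.json_annotation_train,
        args.image_path_train,
    )
    # The registration is visible in this process.
    assert f"{args.dataset_name}-train" in DatasetCatalog.list()


if __name__ == "__main__":
    parser = default_argument_parser()
    parser.add_argument("--dataset_name", help="The Dataset Name")
    parser.add_argument("--json_annotation_train", help="The path to the training set JSON annotation")
    parser.add_argument("--image_path_train", help="The path to the training set image folder")
    args = parser.parse_args()
    launch(
        worker_entry,
        args.num_gpus,
        num_machines=args.num_machines,
        machine_rank=args.machine_rank,
        dist_url=args.dist_url,
        args=(args,),
    )

From the command line nothing changes: the script-specific flags (--dataset_name, --json_annotation_train, --image_path_train, --json_annotation_val, --image_path_val) are still combined with detectron2's standard ones such as --config-file and --num-gpus, and the same invocation now behaves consistently for single- and multi-GPU runs.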
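
A second piece of context the diff only hints at: the file imports DatasetMapper and build_detection_train_loader (see the first hunk's header line), and the Trainer docstring says it is adapted from detectron2's DeepLab train_net.py, where an augmentation list like the one built by get_augs() is handed to the training loader's mapper. A sketch of that presumed wiring follows; the TrainerWithAugs name and the build_train_loader override are assumptions for illustration, not code from this commit.

# Presumed consumption of get_augs(cfg), following the DeepLab train_net.py pattern
# cited in the docstring; get_augs is the helper reformatted in the hunks above.
from detectron2.data import DatasetMapper, build_detection_train_loader
from detectron2.engine import DefaultTrainer


class TrainerWithAugs(DefaultTrainer):  # hypothetical name, for illustration only
    @classmethod
    def build_train_loader(cls, cfg):
        # Hand the custom augmentation list to the mapper used by the training loader.
        mapper = DatasetMapper(cfg, is_train=True, augmentations=get_augs(cfg))
        return build_detection_train_loader(cfg, mapper=mapper)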