Yuhang Tao commited on
Commit
c0fc0e9
·
1 Parent(s): 24cbdc6

alter for multi-gpu training

Browse files

I encounterd this [error](https://github.com/Layout-Parser/layout-model-training/issues/8#issue-1101684484) and found a solution at [detectron2](https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517)

>@zsc1220 it seems the dataset is not registered when you use multiple GPUs. Where do you register the dataset? In train_net you might need to register it in the main() function.

Files changed (1) hide show
  1. tools/train_net.py +41 -22
tools/train_net.py CHANGED
@@ -49,6 +49,7 @@ def get_augs(cfg):
49
  augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
50
  return augs
51
 
 
52
  class Trainer(DefaultTrainer):
53
  """
54
  We use the "DefaultTrainer" which contains pre-defined default logic for
@@ -56,7 +57,7 @@ class Trainer(DefaultTrainer):
56
  are working on a new research project. In that case you can use the cleaner
57
  "SimpleTrainer", or write your own training loop. You can use
58
  "tools/plain_train_net.py" as an example.
59
-
60
  Adapted from:
61
  https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
62
  """
@@ -85,7 +86,8 @@ class Trainer(DefaultTrainer):
85
  model = GeneralizedRCNNWithTTA(cfg, model)
86
  evaluators = [
87
  cls.build_evaluator(
88
- cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
 
89
  )
90
  for name in cfg.DATASETS.TEST
91
  ]
@@ -97,7 +99,8 @@ class Trainer(DefaultTrainer):
97
  def eval_and_save(cls, cfg, model):
98
  evaluators = [
99
  cls.build_evaluator(
100
- cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference")
 
101
  )
102
  for name in cfg.DATASETS.TEST
103
  ]
@@ -105,11 +108,14 @@ class Trainer(DefaultTrainer):
105
  pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
106
  return res
107
 
 
108
  def setup(args):
109
  """
110
  Create configs and perform basic setups.
111
  """
112
  cfg = get_cfg()
 
 
113
  if args.config_file != "":
114
  cfg.merge_from_file(args.config_file)
115
  cfg.merge_from_list(args.opts)
@@ -122,13 +128,21 @@ def setup(args):
122
 
123
  cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
124
  cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
125
-
126
  cfg.freeze()
127
  default_setup(cfg, args)
128
  return cfg
129
 
130
 
131
  def main(args):
 
 
 
 
 
 
 
 
 
132
  cfg = setup(args)
133
 
134
  if args.eval_only:
@@ -137,7 +151,7 @@ def main(args):
137
  cfg.MODEL.WEIGHTS, resume=args.resume
138
  )
139
  res = Trainer.test(cfg, model)
140
-
141
  if cfg.TEST.AUG.ENABLED:
142
  res.update(Trainer.test_with_TTA(cfg, model))
143
  if comm.is_main_process():
@@ -158,11 +172,12 @@ def main(args):
158
  trainer = Trainer(cfg)
159
  trainer.resume_or_load(resume=args.resume)
160
  trainer.register_hooks(
161
- [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
162
  )
163
  if cfg.TEST.AUG.ENABLED:
164
  trainer.register_hooks(
165
- [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
 
166
  )
167
  return trainer.train()
168
 
@@ -171,24 +186,28 @@ if __name__ == "__main__":
171
  parser = default_argument_parser()
172
 
173
  # Extra Configurations for dataset names and paths
174
- parser.add_argument("--dataset_name", default="", help="The Dataset Name")
175
- parser.add_argument("--json_annotation_train", default="", metavar="FILE", help="The path to the training set JSON annotation")
176
- parser.add_argument("--image_path_train", default="", metavar="FILE", help="The path to the training set image folder")
177
- parser.add_argument("--json_annotation_val", default="", metavar="FILE", help="The path to the validation set JSON annotation")
178
- parser.add_argument("--image_path_val", default="", metavar="FILE", help="The path to the validation set image folder")
179
-
 
 
 
 
180
  args = parser.parse_args()
181
  print("Command Line Args:", args)
182
 
183
- # Register Datasets
184
- dataset_name = args.dataset_name
185
- register_coco_instances(f"{dataset_name}-train", {},
186
- args.json_annotation_train,
187
- args.image_path_train)
188
 
189
- register_coco_instances(f"{dataset_name}-val", {},
190
- args.json_annotation_val,
191
- args.image_path_val)
192
 
193
  launch(
194
  main,
@@ -197,4 +216,4 @@ if __name__ == "__main__":
197
  machine_rank=args.machine_rank,
198
  dist_url=args.dist_url,
199
  args=(args,),
200
- )
 
49
  augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
50
  return augs
51
 
52
+
53
  class Trainer(DefaultTrainer):
54
  """
55
  We use the "DefaultTrainer" which contains pre-defined default logic for
 
57
  are working on a new research project. In that case you can use the cleaner
58
  "SimpleTrainer", or write your own training loop. You can use
59
  "tools/plain_train_net.py" as an example.
60
+
61
  Adapted from:
62
  https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
63
  """
 
86
  model = GeneralizedRCNNWithTTA(cfg, model)
87
  evaluators = [
88
  cls.build_evaluator(
89
+ cfg, name, output_folder=os.path.join(
90
+ cfg.OUTPUT_DIR, "inference_TTA")
91
  )
92
  for name in cfg.DATASETS.TEST
93
  ]
 
99
  def eval_and_save(cls, cfg, model):
100
  evaluators = [
101
  cls.build_evaluator(
102
+ cfg, name, output_folder=os.path.join(
103
+ cfg.OUTPUT_DIR, "inference")
104
  )
105
  for name in cfg.DATASETS.TEST
106
  ]
 
108
  pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
109
  return res
110
 
111
+
112
  def setup(args):
113
  """
114
  Create configs and perform basic setups.
115
  """
116
  cfg = get_cfg()
117
+ # alter for multi-gpu training
118
+ #https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517
119
  if args.config_file != "":
120
  cfg.merge_from_file(args.config_file)
121
  cfg.merge_from_list(args.opts)
 
128
 
129
  cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
130
  cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
 
131
  cfg.freeze()
132
  default_setup(cfg, args)
133
  return cfg
134
 
135
 
136
  def main(args):
137
+ # Register Datasets
138
+ dataset_name = args.dataset_name
139
+ register_coco_instances(f"{args.dataset_name}-train", {},
140
+ args.json_annotation_train,
141
+ args.image_path_train)
142
+
143
+ register_coco_instances(f"{args.dataset_name}-val", {},
144
+ args.json_annotation_val,
145
+ args.image_path_val)
146
  cfg = setup(args)
147
 
148
  if args.eval_only:
 
151
  cfg.MODEL.WEIGHTS, resume=args.resume
152
  )
153
  res = Trainer.test(cfg, model)
154
+
155
  if cfg.TEST.AUG.ENABLED:
156
  res.update(Trainer.test_with_TTA(cfg, model))
157
  if comm.is_main_process():
 
172
  trainer = Trainer(cfg)
173
  trainer.resume_or_load(resume=args.resume)
174
  trainer.register_hooks(
175
+ [hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
176
  )
177
  if cfg.TEST.AUG.ENABLED:
178
  trainer.register_hooks(
179
+ [hooks.EvalHook(
180
+ 0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
181
  )
182
  return trainer.train()
183
 
 
186
  parser = default_argument_parser()
187
 
188
  # Extra Configurations for dataset names and paths
189
+ parser.add_argument("--dataset_name",
190
+ default="", help="The Dataset Name")
191
+ parser.add_argument("--json_annotation_train", default="", metavar="FILE",
192
+ help="The path to the training set JSON annotation")
193
+ parser.add_argument("--image_path_train", default="",
194
+ metavar="FILE", help="The path to the training set image folder")
195
+ parser.add_argument("--json_annotation_val", default="", metavar="FILE",
196
+ help="The path to the validation set JSON annotation")
197
+ parser.add_argument("--image_path_val", default="",
198
+ metavar="FILE", help="The path to the validation set image folder")
199
  args = parser.parse_args()
200
  print("Command Line Args:", args)
201
 
202
+ # # Register Datasets
203
+ # dataset_name = args.dataset_name
204
+ # register_coco_instances(f"{dataset_name}-train", {},
205
+ # args.json_annotation_train,
206
+ # args.image_path_train)
207
 
208
+ # register_coco_instances(f"{dataset_name}-val", {},
209
+ # args.json_annotation_val,
210
+ # args.image_path_val)
211
 
212
  launch(
213
  main,
 
216
  machine_rank=args.machine_rank,
217
  dist_url=args.dist_url,
218
  args=(args,),
219
+ )