Yuhang Tao
committed on
Commit
·
c0fc0e9
1
Parent(s):
24cbdc6
alter for multi-gpu training
Browse files
I encountered this [error](https://github.com/Layout-Parser/layout-model-training/issues/8#issue-1101684484) and found a solution in the [detectron2](https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517) issue tracker:
>@zsc1220 it seems the dataset is not registered when you use multiple GPUs. Where do you register the dataset? In train_net you might need to register it in the main() function.
- tools/train_net.py +41 -22
tools/train_net.py
CHANGED
@@ -49,6 +49,7 @@ def get_augs(cfg):
|
|
49 |
augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
|
50 |
return augs
|
51 |
|
|
|
52 |
class Trainer(DefaultTrainer):
|
53 |
"""
|
54 |
We use the "DefaultTrainer" which contains pre-defined default logic for
|
@@ -56,7 +57,7 @@ class Trainer(DefaultTrainer):
|
|
56 |
are working on a new research project. In that case you can use the cleaner
|
57 |
"SimpleTrainer", or write your own training loop. You can use
|
58 |
"tools/plain_train_net.py" as an example.
|
59 |
-
|
60 |
Adapted from:
|
61 |
https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
|
62 |
"""
|
@@ -85,7 +86,8 @@ class Trainer(DefaultTrainer):
|
|
85 |
model = GeneralizedRCNNWithTTA(cfg, model)
|
86 |
evaluators = [
|
87 |
cls.build_evaluator(
|
88 |
-
cfg, name, output_folder=os.path.join(
|
|
|
89 |
)
|
90 |
for name in cfg.DATASETS.TEST
|
91 |
]
|
@@ -97,7 +99,8 @@ class Trainer(DefaultTrainer):
|
|
97 |
def eval_and_save(cls, cfg, model):
|
98 |
evaluators = [
|
99 |
cls.build_evaluator(
|
100 |
-
cfg, name, output_folder=os.path.join(
|
|
|
101 |
)
|
102 |
for name in cfg.DATASETS.TEST
|
103 |
]
|
@@ -105,11 +108,14 @@ class Trainer(DefaultTrainer):
|
|
105 |
pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
|
106 |
return res
|
107 |
|
|
|
108 |
def setup(args):
|
109 |
"""
|
110 |
Create configs and perform basic setups.
|
111 |
"""
|
112 |
cfg = get_cfg()
|
|
|
|
|
113 |
if args.config_file != "":
|
114 |
cfg.merge_from_file(args.config_file)
|
115 |
cfg.merge_from_list(args.opts)
|
@@ -122,13 +128,21 @@ def setup(args):
|
|
122 |
|
123 |
cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
|
124 |
cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
|
125 |
-
|
126 |
cfg.freeze()
|
127 |
default_setup(cfg, args)
|
128 |
return cfg
|
129 |
|
130 |
|
131 |
def main(args):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
cfg = setup(args)
|
133 |
|
134 |
if args.eval_only:
|
@@ -137,7 +151,7 @@ def main(args):
|
|
137 |
cfg.MODEL.WEIGHTS, resume=args.resume
|
138 |
)
|
139 |
res = Trainer.test(cfg, model)
|
140 |
-
|
141 |
if cfg.TEST.AUG.ENABLED:
|
142 |
res.update(Trainer.test_with_TTA(cfg, model))
|
143 |
if comm.is_main_process():
|
@@ -158,11 +172,12 @@ def main(args):
|
|
158 |
trainer = Trainer(cfg)
|
159 |
trainer.resume_or_load(resume=args.resume)
|
160 |
trainer.register_hooks(
|
161 |
-
|
162 |
)
|
163 |
if cfg.TEST.AUG.ENABLED:
|
164 |
trainer.register_hooks(
|
165 |
-
[hooks.EvalHook(
|
|
|
166 |
)
|
167 |
return trainer.train()
|
168 |
|
@@ -171,24 +186,28 @@ if __name__ == "__main__":
|
|
171 |
parser = default_argument_parser()
|
172 |
|
173 |
# Extra Configurations for dataset names and paths
|
174 |
-
parser.add_argument("--dataset_name",
|
175 |
-
|
176 |
-
parser.add_argument("--
|
177 |
-
|
178 |
-
parser.add_argument("--
|
179 |
-
|
|
|
|
|
|
|
|
|
180 |
args = parser.parse_args()
|
181 |
print("Command Line Args:", args)
|
182 |
|
183 |
-
# Register Datasets
|
184 |
-
dataset_name = args.dataset_name
|
185 |
-
register_coco_instances(f"{dataset_name}-train", {},
|
186 |
-
|
187 |
-
|
188 |
|
189 |
-
register_coco_instances(f"{dataset_name}-val", {},
|
190 |
-
|
191 |
-
|
192 |
|
193 |
launch(
|
194 |
main,
|
@@ -197,4 +216,4 @@ if __name__ == "__main__":
|
|
197 |
machine_rank=args.machine_rank,
|
198 |
dist_url=args.dist_url,
|
199 |
args=(args,),
|
200 |
-
)
|
|
|
49 |
augs.append(T.RandomRotation(angle=[-90.0, 0.0]))
|
50 |
return augs
|
51 |
|
52 |
+
|
53 |
class Trainer(DefaultTrainer):
|
54 |
"""
|
55 |
We use the "DefaultTrainer" which contains pre-defined default logic for
|
|
|
57 |
are working on a new research project. In that case you can use the cleaner
|
58 |
"SimpleTrainer", or write your own training loop. You can use
|
59 |
"tools/plain_train_net.py" as an example.
|
60 |
+
|
61 |
Adapted from:
|
62 |
https://github.com/facebookresearch/detectron2/blob/master/projects/DeepLab/train_net.py
|
63 |
"""
|
|
|
86 |
model = GeneralizedRCNNWithTTA(cfg, model)
|
87 |
evaluators = [
|
88 |
cls.build_evaluator(
|
89 |
+
cfg, name, output_folder=os.path.join(
|
90 |
+
cfg.OUTPUT_DIR, "inference_TTA")
|
91 |
)
|
92 |
for name in cfg.DATASETS.TEST
|
93 |
]
|
|
|
99 |
def eval_and_save(cls, cfg, model):
|
100 |
evaluators = [
|
101 |
cls.build_evaluator(
|
102 |
+
cfg, name, output_folder=os.path.join(
|
103 |
+
cfg.OUTPUT_DIR, "inference")
|
104 |
)
|
105 |
for name in cfg.DATASETS.TEST
|
106 |
]
|
|
|
108 |
pd.DataFrame(res).to_csv(os.path.join(cfg.OUTPUT_DIR, 'eval.csv'))
|
109 |
return res
|
110 |
|
111 |
+
|
112 |
def setup(args):
|
113 |
"""
|
114 |
Create configs and perform basic setups.
|
115 |
"""
|
116 |
cfg = get_cfg()
|
117 |
+
# alter for multi-gpu training
|
118 |
+
#https://github.com/facebookresearch/detectron2/issues/253#issuecomment-554216517
|
119 |
if args.config_file != "":
|
120 |
cfg.merge_from_file(args.config_file)
|
121 |
cfg.merge_from_list(args.opts)
|
|
|
128 |
|
129 |
cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
|
130 |
cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
|
|
|
131 |
cfg.freeze()
|
132 |
default_setup(cfg, args)
|
133 |
return cfg
|
134 |
|
135 |
|
136 |
def main(args):
|
137 |
+
# Register Datasets
|
138 |
+
dataset_name = args.dataset_name
|
139 |
+
register_coco_instances(f"{args.dataset_name}-train", {},
|
140 |
+
args.json_annotation_train,
|
141 |
+
args.image_path_train)
|
142 |
+
|
143 |
+
register_coco_instances(f"{args.dataset_name}-val", {},
|
144 |
+
args.json_annotation_val,
|
145 |
+
args.image_path_val)
|
146 |
cfg = setup(args)
|
147 |
|
148 |
if args.eval_only:
|
|
|
151 |
cfg.MODEL.WEIGHTS, resume=args.resume
|
152 |
)
|
153 |
res = Trainer.test(cfg, model)
|
154 |
+
|
155 |
if cfg.TEST.AUG.ENABLED:
|
156 |
res.update(Trainer.test_with_TTA(cfg, model))
|
157 |
if comm.is_main_process():
|
|
|
172 |
trainer = Trainer(cfg)
|
173 |
trainer.resume_or_load(resume=args.resume)
|
174 |
trainer.register_hooks(
|
175 |
+
[hooks.EvalHook(0, lambda: trainer.eval_and_save(cfg, trainer.model))]
|
176 |
)
|
177 |
if cfg.TEST.AUG.ENABLED:
|
178 |
trainer.register_hooks(
|
179 |
+
[hooks.EvalHook(
|
180 |
+
0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
|
181 |
)
|
182 |
return trainer.train()
|
183 |
|
|
|
186 |
parser = default_argument_parser()
|
187 |
|
188 |
# Extra Configurations for dataset names and paths
|
189 |
+
parser.add_argument("--dataset_name",
|
190 |
+
default="", help="The Dataset Name")
|
191 |
+
parser.add_argument("--json_annotation_train", default="", metavar="FILE",
|
192 |
+
help="The path to the training set JSON annotation")
|
193 |
+
parser.add_argument("--image_path_train", default="",
|
194 |
+
metavar="FILE", help="The path to the training set image folder")
|
195 |
+
parser.add_argument("--json_annotation_val", default="", metavar="FILE",
|
196 |
+
help="The path to the validation set JSON annotation")
|
197 |
+
parser.add_argument("--image_path_val", default="",
|
198 |
+
metavar="FILE", help="The path to the validation set image folder")
|
199 |
args = parser.parse_args()
|
200 |
print("Command Line Args:", args)
|
201 |
|
202 |
+
# # Register Datasets
|
203 |
+
# dataset_name = args.dataset_name
|
204 |
+
# register_coco_instances(f"{dataset_name}-train", {},
|
205 |
+
# args.json_annotation_train,
|
206 |
+
# args.image_path_train)
|
207 |
|
208 |
+
# register_coco_instances(f"{dataset_name}-val", {},
|
209 |
+
# args.json_annotation_val,
|
210 |
+
# args.image_path_val)
|
211 |
|
212 |
launch(
|
213 |
main,
|
|
|
216 |
machine_rank=args.machine_rank,
|
217 |
dist_url=args.dist_url,
|
218 |
args=(args,),
|
219 |
+
)
|