loliipopshock committed
Commit e483cda · 1 Parent(s): f9b14aa

Add training scripts

Files changed (1)
  1. tools/train_net.py +150 -0
tools/train_net.py ADDED
@@ -0,0 +1,150 @@
+"""
+The script is based on https://github.com/facebookresearch/detectron2/blob/master/tools/train_net.py.
+"""
+
+import logging
+import os
+import json
+from collections import OrderedDict
+import torch
+import sys
+import detectron2.utils.comm as comm
+from detectron2.checkpoint import DetectionCheckpointer
+from detectron2.config import get_cfg
+
+from detectron2.data import MetadataCatalog, DatasetCatalog
+from detectron2.data.datasets import register_coco_instances
+
+from detectron2.engine import DefaultTrainer, default_argument_parser, default_setup, hooks, launch
+from detectron2.evaluation import (
+    COCOEvaluator,
+    DatasetEvaluators,
+    SemSegEvaluator,
+    verify_results,
+)
+from detectron2.modeling import GeneralizedRCNNWithTTA
+import pandas as pd
+
+class Trainer(DefaultTrainer):
+    """
+    We use the "DefaultTrainer", which contains pre-defined default logic for
+    the standard training workflow. It may not work for you, especially if you
+    are working on a new research project. In that case you can use the cleaner
+    "SimpleTrainer", or write your own training loop. You can use
+    "tools/plain_train_net.py" as an example.
+    """
+
+    @classmethod
+    def build_evaluator(cls, cfg, dataset_name, output_folder=None):
+        """
+        Returns:
+            DatasetEvaluator or None
+
+        It is not implemented by default.
+        """
+        return COCOEvaluator(dataset_name, cfg, True, output_folder)
+
+    @classmethod
+    def test_with_TTA(cls, cfg, model):
+        logger = logging.getLogger("detectron2.trainer")
+        # At the end of training, run an evaluation with TTA.
+        # Only supports some R-CNN models.
+        logger.info("Running inference with test-time augmentation ...")
+        model = GeneralizedRCNNWithTTA(cfg, model)
+        evaluators = [
+            cls.build_evaluator(
+                cfg, name, output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA")
+            )
+            for name in cfg.DATASETS.TEST
+        ]
+        res = cls.test(cfg, model, evaluators)
+        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
+        return res
+
+
+def setup(args):
+    """
+    Create configs and perform basic setups.
+    """
+    cfg = get_cfg()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+
+    with open(args.json_annotation_train, 'r') as fp:
+        anno_file = json.load(fp)
+
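+    # A COCO-format annotation file lists one entry per class under the
+    # "categories" key, so its length gives the number of classes.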
+    cfg.MODEL.ROI_HEADS.NUM_CLASSES = len(anno_file["categories"])
+    del anno_file
+
+    cfg.DATASETS.TRAIN = (f"{args.dataset_name}-train",)
+    cfg.DATASETS.TEST = (f"{args.dataset_name}-val",)
+
+    cfg.freeze()
+    default_setup(cfg, args)
+    return cfg
+
+
+def main(args):
+    cfg = setup(args)
+
+    if args.eval_only:
+        model = Trainer.build_model(cfg)
+        DetectionCheckpointer(model, save_dir=cfg.OUTPUT_DIR).resume_or_load(
+            cfg.MODEL.WEIGHTS, resume=args.resume
+        )
+        res = Trainer.test(cfg, model)
+
+        if cfg.TEST.AUG.ENABLED:
+            res.update(Trainer.test_with_TTA(cfg, model))
+        if comm.is_main_process():
+            verify_results(cfg, res)
+
+        # Save the evaluation results
+        pd.DataFrame(res).to_csv(f'{cfg.OUTPUT_DIR}/eval.csv')
+        return res
+
+    """
+    If you'd like to do anything fancier than the standard training logic,
+    consider writing your own training loop (see plain_train_net.py) or
+    subclassing the trainer.
+    """
+    trainer = Trainer(cfg)
+    trainer.resume_or_load(resume=args.resume)
+    if cfg.TEST.AUG.ENABLED:
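+        # EvalHook with period=0 runs its function once, after the final iteration.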
+        trainer.register_hooks(
+            [hooks.EvalHook(0, lambda: trainer.test_with_TTA(cfg, trainer.model))]
+        )
+    return trainer.train()
+
+
+if __name__ == "__main__":
+    parser = default_argument_parser()
+
+    # Extra configurations for dataset names and paths
+    parser.add_argument("--dataset_name", default="", help="The dataset name")
+    parser.add_argument("--json_annotation_train", default="", metavar="FILE", help="The path to the training set JSON annotation")
+    parser.add_argument("--image_path_train", default="", metavar="FILE", help="The path to the training set image folder")
+    parser.add_argument("--json_annotation_val", default="", metavar="FILE", help="The path to the validation set JSON annotation")
+    parser.add_argument("--image_path_val", default="", metavar="FILE", help="The path to the validation set image folder")
+
+    args = parser.parse_args()
+    print("Command Line Args:", args)
+
+    # Register Datasets
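+    # The split names registered here must match the cfg.DATASETS.TRAIN/TEST
+    # values assembled in setup() above.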
+    dataset_name = args.dataset_name
+    register_coco_instances(f"{dataset_name}-train", {},
+                            args.json_annotation_train,
+                            args.image_path_train)
+
+    register_coco_instances(f"{dataset_name}-val", {},
+                            args.json_annotation_val,
+                            args.image_path_val)
+
+    launch(
+        main,
+        args.num_gpus,
+        num_machines=args.num_machines,
+        machine_rank=args.machine_rank,
+        dist_url=args.dist_url,
+        args=(args,),
+    )
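For reference, a sketch of how the script might be invoked. The config and data paths below are placeholders, not part of this commit; --config-file, --num-gpus, and --eval-only come from detectron2's default_argument_parser, while the dataset flags are defined in the script above:

    python tools/train_net.py \
        --config-file path/to/config.yaml \
        --num-gpus 1 \
        --dataset_name my-dataset \
        --json_annotation_train path/to/train.json \
        --image_path_train path/to/train_images \
        --json_annotation_val path/to/val.json \
        --image_path_val path/to/val_images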