HichTala commited on
Commit
a60ce1a
·
verified ·
1 Parent(s): 4c81f22

Upload configuration_diffusiondet.py

Browse files
Files changed (1) hide show
  1. configuration_diffusiondet.py +167 -0
configuration_diffusiondet.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+ from transformers.models.auto import CONFIG_MAPPING
4
+ from transformers.utils.backbone_utils import verify_backbone_config_arguments
5
+
6
+ from transformers.utils import logging, PushToHubMixin
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+ class DiffusionDetConfig(PretrainedConfig):
11
+
12
+ model_type = "diffusiondet"
13
+
14
+ def __init__(
15
+ self,
16
+ use_timm_backbone=True,
17
+ backbone_config=None,
18
+ num_channels=3,
19
+ pixel_mean=(123.675, 116.280, 103.530),
20
+ pixel_std=(58.395, 57.120, 57.375),
21
+ resnet_out_features=("res2", "res3", "res4", "res5"),
22
+ resnet_in_features=("res2", "res3", "res4", "res5"),
23
+ roi_head_in_features=("p2", "p3", "p4", "p5"),
24
+ fpn_out_channels=256,
25
+ pooler_resolution=7,
26
+ sampling_ratio=2,
27
+ num_proposals=300,
28
+ num_attn_heads=8,
29
+ dropout=0.0,
30
+ dim_feedforward=2048,
31
+ activation="relu",
32
+ hidden_dim=256,
33
+ num_cls=1,
34
+ num_reg=3,
35
+ num_heads=6,
36
+ num_dynamic=2,
37
+ dim_dynamic=64,
38
+ class_weight=2.0,
39
+ giou_weight=2.0,
40
+ l1_weight=5.0,
41
+ deep_supervision=True,
42
+ no_object_weight=0.1,
43
+ use_focal=True,
44
+ use_fed_loss=False,
45
+ alpha=0.25,
46
+ gamma=2.0,
47
+ prior_prob=0.01,
48
+ ota_k=5,
49
+ snr_scale=2.0,
50
+ sample_step=1,
51
+ use_nms=True,
52
+ swin_size="B",
53
+ use_swin_checkpoint=False,
54
+ swin_out_features=(0, 1, 2, 3),
55
+ optimizer="ADAMW",
56
+ backbone_multiplier=1.0,
57
+ backbone='resnet50',
58
+ use_pretrained_backbone=True,
59
+ backbone_kwargs=None,
60
+ dilation=False,
61
+ **kwargs
62
+ ):
63
+ # We default to values which were previously hard-coded in the model. This enables configurability of the config
64
+ # while keeping the default behavior the same.
65
+ if use_timm_backbone and backbone_kwargs is None:
66
+ backbone_kwargs = {}
67
+ if dilation:
68
+ backbone_kwargs["output_stride"] = 16
69
+ backbone_kwargs["out_indices"] = [1, 2, 3, 4]
70
+ backbone_kwargs["in_chans"] = num_channels
71
+ # Backwards compatibility
72
+ elif not use_timm_backbone and backbone in (None, "resnet50"):
73
+ if backbone_config is None:
74
+ logger.info("`backbone_config` is `None`. Initializing the config with the default `ResNet` backbone.")
75
+ backbone_config = CONFIG_MAPPING["resnet"](out_features=["stage4"])
76
+ elif isinstance(backbone_config, dict):
77
+ backbone_model_type = backbone_config.get("model_type")
78
+ config_class = CONFIG_MAPPING[backbone_model_type]
79
+ backbone_config = config_class.from_dict(backbone_config)
80
+ backbone = None
81
+ # set timm attributes to None
82
+ dilation = None
83
+
84
+ verify_backbone_config_arguments(
85
+ use_timm_backbone=use_timm_backbone,
86
+ use_pretrained_backbone=use_pretrained_backbone,
87
+ backbone=backbone,
88
+ backbone_config=backbone_config,
89
+ backbone_kwargs=backbone_kwargs,
90
+ )
91
+
92
+ # Auto mapping
93
+ self.auto_map = {
94
+ "AutoConfig": "diffusiondet.configuration_diffusiondet.DiffusionDetConfig",
95
+ "AutoModelForObjectDetection": "diffusiondet.modeling_diffusiondet.DiffusionDet"
96
+ }
97
+
98
+ # Backbone.
99
+ self.use_timm_backbone = use_timm_backbone
100
+ self.backbone_config = backbone_config
101
+ self.num_channels = num_channels
102
+ self.backbone = backbone
103
+ self.use_pretrained_backbone = use_pretrained_backbone
104
+ self.backbone_kwargs = backbone_kwargs
105
+ self.dilation = dilation
106
+ self.fpn_out_channels = fpn_out_channels
107
+
108
+ # Model.
109
+ self.pixel_mean = pixel_mean
110
+ self.pixel_std = pixel_std
111
+ self.resnet_out_features = resnet_out_features
112
+ self.resnet_in_features = resnet_in_features
113
+ self.roi_head_in_features = roi_head_in_features
114
+ self.pooler_resolution = pooler_resolution
115
+ self.sampling_ratio = sampling_ratio
116
+ self.num_proposals = num_proposals
117
+
118
+ # RCNN Head.
119
+ self.num_attn_heads = num_attn_heads
120
+ self.dropout = dropout
121
+ self.dim_feedforward = dim_feedforward
122
+ self.activation = activation
123
+ self.hidden_dim = hidden_dim
124
+ self.num_cls = num_cls
125
+ self.num_reg = num_reg
126
+ self.num_heads = num_heads
127
+
128
+ # Dynamic Conv.
129
+ self.num_dynamic = num_dynamic
130
+ self.dim_dynamic = dim_dynamic
131
+
132
+ # Loss.
133
+ self.class_weight = class_weight
134
+ self.giou_weight = giou_weight
135
+ self.l1_weight = l1_weight
136
+ self.deep_supervision = deep_supervision
137
+ self.no_object_weight = no_object_weight
138
+
139
+ # Focal Loss.
140
+ self.use_focal = use_focal
141
+ self.use_fed_loss = use_fed_loss
142
+ self.alpha = alpha
143
+ self.gamma = gamma
144
+ self.prior_prob = prior_prob
145
+
146
+ # Dynamic K
147
+ self.ota_k = ota_k
148
+
149
+ # Diffusion
150
+ self.snr_scale = snr_scale
151
+ self.sample_step = sample_step
152
+
153
+ # Inference
154
+ self.use_nms = use_nms
155
+
156
+ # Swin Backbones
157
+ self.swin_size = swin_size
158
+ self.use_swin_checkpoint = use_swin_checkpoint
159
+ self.swin_out_features = swin_out_features
160
+
161
+ # Optimizer.
162
+ self.optimizer = optimizer
163
+ self.backbone_multiplier = backbone_multiplier
164
+
165
+ self.num_labels = 80
166
+
167
+ super().__init__()