JasonSmithSO committed
Commit 0a3dbb2 · verified · 1 Parent(s): 0034848

Upload 564 files

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete set of 564 files.
Files changed (50)
  1. custom_midas_repo/LICENSE +21 -0
  2. custom_midas_repo/README.md +259 -0
  3. custom_midas_repo/__init__.py +0 -0
  4. custom_midas_repo/hubconf.py +435 -0
  5. custom_midas_repo/midas/__init__.py +0 -0
  6. custom_midas_repo/midas/backbones/__init__.py +0 -0
  7. custom_midas_repo/midas/backbones/beit.py +196 -0
  8. custom_midas_repo/midas/backbones/levit.py +106 -0
  9. custom_midas_repo/midas/backbones/next_vit.py +39 -0
  10. custom_midas_repo/midas/backbones/swin.py +13 -0
  11. custom_midas_repo/midas/backbones/swin2.py +34 -0
  12. custom_midas_repo/midas/backbones/swin_common.py +52 -0
  13. custom_midas_repo/midas/backbones/utils.py +249 -0
  14. custom_midas_repo/midas/backbones/vit.py +221 -0
  15. custom_midas_repo/midas/base_model.py +16 -0
  16. custom_midas_repo/midas/blocks.py +439 -0
  17. custom_midas_repo/midas/dpt_depth.py +166 -0
  18. custom_midas_repo/midas/midas_net.py +76 -0
  19. custom_midas_repo/midas/midas_net_custom.py +128 -0
  20. custom_midas_repo/midas/model_loader.py +242 -0
  21. custom_midas_repo/midas/transforms.py +234 -0
  22. custom_mmpkg/__init__.py +1 -0
  23. custom_mmpkg/custom_mmcv/__init__.py +15 -0
  24. custom_mmpkg/custom_mmcv/arraymisc/__init__.py +4 -0
  25. custom_mmpkg/custom_mmcv/arraymisc/quantization.py +55 -0
  26. custom_mmpkg/custom_mmcv/cnn/__init__.py +41 -0
  27. custom_mmpkg/custom_mmcv/cnn/alexnet.py +61 -0
  28. custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py +35 -0
  29. custom_mmpkg/custom_mmcv/cnn/bricks/activation.py +92 -0
  30. custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py +125 -0
  31. custom_mmpkg/custom_mmcv/cnn/bricks/conv.py +44 -0
  32. custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py +62 -0
  33. custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py +206 -0
  34. custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py +148 -0
  35. custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py +96 -0
  36. custom_mmpkg/custom_mmcv/cnn/bricks/drop.py +65 -0
  37. custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py +412 -0
  38. custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py +34 -0
  39. custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py +29 -0
  40. custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py +306 -0
  41. custom_mmpkg/custom_mmcv/cnn/bricks/norm.py +144 -0
  42. custom_mmpkg/custom_mmcv/cnn/bricks/padding.py +36 -0
  43. custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py +88 -0
  44. custom_mmpkg/custom_mmcv/cnn/bricks/registry.py +16 -0
  45. custom_mmpkg/custom_mmcv/cnn/bricks/scale.py +21 -0
  46. custom_mmpkg/custom_mmcv/cnn/bricks/swish.py +25 -0
  47. custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py +595 -0
  48. custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py +84 -0
  49. custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py +180 -0
  50. custom_mmpkg/custom_mmcv/cnn/builder.py +30 -0
custom_midas_repo/LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
custom_midas_repo/README.md ADDED
@@ -0,0 +1,259 @@
1
+ ## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
2
+
3
+ This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3):
4
+
5
+ >Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
6
+ René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun
7
+
8
+
9
+ and our [preprint](https://arxiv.org/abs/2103.13413):
10
+
11
+ > Vision Transformers for Dense Prediction
12
+ > René Ranftl, Alexey Bochkovskiy, Vladlen Koltun
13
+
14
+
15
+ MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with
16
+ multi-objective optimization.
17
+ The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2).
18
+ The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters.
19
+
20
+ ![](figures/Improvement_vs_FPS.png)
21
+
22
+ ### Setup
23
+
24
+ 1) Pick one or more models and download the corresponding weights to the `weights` folder:
25
+
26
+ MiDaS 3.1
27
+ - For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)
28
+ - For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)
29
+ - For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)
30
+ - For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin)
31
+
32
+ MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt)
33
+
34
+ MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt)
35
+
36
+ 2) Set up dependencies:
37
+
38
+ ```shell
39
+ conda env create -f environment.yaml
40
+ conda activate midas-py310
41
+ ```
42
+
43
+ #### optional
44
+
45
+ For the Next-ViT model, execute
46
+
47
+ ```shell
48
+ git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit
49
+ ```
50
+
51
+ For the OpenVINO model, install
52
+
53
+ ```shell
54
+ pip install openvino
55
+ ```
56
+
57
+ ### Usage
58
+
59
+ 1) Place one or more input images in the folder `input`.
60
+
61
+ 2) Run the model with
62
+
63
+ ```shell
64
+ python run.py --model_type <model_type> --input_path input --output_path output
65
+ ```
66
+ where ```<model_type>``` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type),
67
+ [dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type),
68
+ [dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type),
69
+ [dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type),
70
+ [midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type).
71
+
72
+ 3) The resulting depth maps are written to the `output` folder.
73
+
74
+ #### optional
75
+
76
+ 1) By default, inference resizes the height of the input images to the model's training size so that they fit
+ into the encoder. This size is given by the numbers in the model names of the [accuracy table](#accuracy). Some
+ models support not just a single inference height but a range of heights. Feel free to explore different heights
+ by appending the extra command line argument `--height` (see the sketch after this list). Unsupported height
+ values will throw an error. Note that using this argument may decrease the model accuracy.
81
+ 2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is
82
+ supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution,
83
+ disregarding the aspect ratio while preserving the height, use the command line argument `--square`.
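For illustration, the resizing behaviour described above corresponds to the `Resize` transform that `hubconf.py` composes for each model. Below is a minimal sketch of building such a transform for a non-default inference height; the value 672 is a hypothetical example that must be supported by the chosen model, and this is not the exact code path of `run.py --height`:

```python
# Sketch only: an input transform for a non-default inference height, mirroring
# the transforms assembled in hubconf.py. The height 672 is a hypothetical example.
import cv2
import torch
from torchvision.transforms import Compose

from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet

height = 672  # assumed to be a multiple of 32 and supported by the chosen model

custom_transform = Compose(
    [
        lambda img: {"image": img / 255.0},
        Resize(
            height,
            height,
            resize_target=None,
            keep_aspect_ratio=True,   # set to False to force a square input, as --square does
            ensure_multiple_of=32,
            resize_method="minimal",
            image_interpolation_method=cv2.INTER_CUBIC,
        ),
        NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # DPT-style normalization
        PrepareForNet(),
        lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
    ]
)
```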
84
+
85
+ #### via Camera
86
+
87
+ If you want the input images to be grabbed from the camera and shown in a window, omit the input and output
+ paths and choose a model type as shown above:
89
+
90
+ ```shell
91
+ python run.py --model_type <model_type> --side
92
+ ```
93
+
94
+ The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown
95
+ side-by-side for comparison.
96
+
97
+ #### via Docker
98
+
99
+ 1) Make sure you have installed Docker and the
100
+ [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)).
101
+
102
+ 2) Build the Docker image:
103
+
104
+ ```shell
105
+ docker build -t midas .
106
+ ```
107
+
108
+ 3) Run inference:
109
+
110
+ ```shell
111
+ docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas
112
+ ```
113
+
114
+ This command passes through all of your NVIDIA GPUs to the container, mounts the
115
+ `input` and `output` directories and then runs the inference.
116
+
117
+ #### via PyTorch Hub
118
+
119
+ The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/)
120
+
121
+ #### via TensorFlow or ONNX
122
+
123
+ See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory.
124
+
125
+ Currently only supports MiDaS v2.1.
126
+
127
+
128
+ #### via Mobile (iOS / Android)
129
+
130
+ See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory.
131
+
132
+ #### via ROS1 (Robot Operating System)
133
+
134
+ See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory.
135
+
136
+ Currently only supports MiDaS v2.1. DPT-based models to be added.
137
+
138
+
139
+ ### Accuracy
140
+
141
+ We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets
142
+ (see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**.
143
+ $\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to
144
+ MiDaS 3.0 DPT<sub>L-384</sub>. The models are grouped by the height used for inference, whereas the square training resolution is given by
145
+ the numbers in the model names. The table also shows the **number of parameters** (in millions) and the
146
+ **frames per second** for inference at the training resolution (for GPU RTX 3090):
147
+
148
+ | MiDaS Model | DIW </br><sup>WHDR</sup> | Eth3d </br><sup>AbsRel</sup> | Sintel </br><sup>AbsRel</sup> | TUM </br><sup>δ1</sup> | KITTI </br><sup>δ1</sup> | NYUv2 </br><sup>δ1</sup> | $\color{green}{\textsf{Imp.}}$ </br><sup>%</sup> | Par.</br><sup>M</sup> | FPS</br><sup>&nbsp;</sup> |
149
+ |-----------------------------------------------------------------------------------------------------------------------|-------------------------:|-----------------------------:|------------------------------:|-------------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|--------------------------:|
150
+ | **Inference height 512** | | | | | | | | | |
151
+ | [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** |
152
+ | [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** |
153
+ | | | | | | | | | | |
154
+ | **Inference height 384** | | | | | | | | | |
155
+ | [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 |
156
+ | [v3.1 Swin2<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 |
157
+ | [v3.1 Swin2<sub>B-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 |
158
+ | [v3.1 Swin<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 |
159
+ | [v3.1 BEiT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 |
160
+ | [v3.1 Next-ViT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 |
161
+ | [v3.1 BEiT<sub>B-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 |
162
+ | [v3.0 DPT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** |
163
+ | [v3.0 DPT<sub>H-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 |
164
+ | [v2.1 Large<sub>384</sub>](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 |
165
+ | | | | | | | | | | |
166
+ | **Inference height 256** | | | | | | | | | |
167
+ | [v3.1 Swin2<sub>T-256</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 |
168
+ | [v2.1 Small<sub>256</sub>](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** |
169
+ | | | | | | | | | | |
170
+ | **Inference height 224** | | | | | | | | | |
171
+ | [v3.1 LeViT<sub>224</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** |
172
+
173
+ &ast; No zero-shot error, because models are also trained on KITTI and NYU Depth V2\
174
+ $\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model
175
+ does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other
176
+ validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the
177
+ improvement, because these quantities are averages over the pixels of an image and do not take into account the
178
+ advantage of more details due to a higher resolution.\
179
+ Best values per column and same validation height in bold
180
+
181
+ #### Improvement
182
+
183
+ The improvement in the above table is defined via the zero-shot errors relative to MiDaS v3.0
+ DPT<sub>L-384</sub>, averaged over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then
+ the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%.
186
+
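As a quick sanity check of the formula, it can be evaluated directly; the per-model errors below are placeholders, only the reference row (v3.0 DPT<sub>L-384</sub>) is taken from the table:

```python
# Evaluate the improvement formula 100 * (1 - (1/6) * sum_d eps_d / eps_d,ref).
def improvement(errors, reference_errors):
    ratios = [e / r for e, r in zip(errors, reference_errors)]
    return 100.0 * (1.0 - sum(ratios) / len(ratios))

eps_ref = [0.1082, 0.0888, 0.2697, 9.97, 8.46, 8.32]   # v3.0 DPT_L-384 row from the table
eps_model = [0.11, 0.07, 0.24, 7.0, 6.5, 2.5]          # placeholder values, not from the table
print(f"Imp. = {improvement(eps_model, eps_ref):.0f} %")
```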
187
+ Note that the improvements of 10% for MiDaS v2.0 &rarr; v2.1 and 21% for MiDaS v2.1 &rarr; v3.0 are not visible from the
188
+ improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large<sub>384</sub>
189
+ and v2.0 Large<sub>384</sub> respectively instead of v3.0 DPT<sub>L-384</sub>.
190
+
191
+ ### Depth map comparison
192
+
193
+ Zoom in for better visibility
194
+ ![](figures/Comparison.png)
195
+
196
+ ### Speed on Camera Feed
197
+
198
+ Test configuration
199
+ - Windows 10
200
+ - 11th Gen Intel Core i7-1185G7 3.00GHz
201
+ - 16GB RAM
202
+ - Camera resolution 640x480
203
+ - openvino_midas_v21_small_256
204
+
205
+ Speed: 22 FPS
206
+
207
+ ### Changelog
208
+
209
+ * [Dec 2022] Released MiDaS v3.1:
210
+ - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf))
211
+ - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split
212
+ - Best model, BEiT<sub>Large 512</sub>, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0
213
+ - Integrated live depth estimation from camera feed
214
+ * [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large).
215
+ * [Apr 2021] Released MiDaS v3.0:
216
+ - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1
217
+ - Additional models can be found [here](https://github.com/isl-org/DPT)
218
+ * [Nov 2020] Released MiDaS v2.1:
219
+ - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2)
220
+ - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms.
221
+ - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android)
222
+ - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots
223
+ * [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/).
224
+ * [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust
225
+ * [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1))
226
+
227
+ ### Citation
228
+
229
+ Please cite our paper if you use this code or any of the models:
230
+ ```
231
+ @ARTICLE {Ranftl2022,
232
+ author = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun",
233
+ title = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer",
234
+ journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence",
235
+ year = "2022",
236
+ volume = "44",
237
+ number = "3"
238
+ }
239
+ ```
240
+
241
+ If you use a DPT-based model, please also cite:
242
+
243
+ ```
244
+ @article{Ranftl2021,
245
+ author = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
246
+ title = {Vision Transformers for Dense Prediction},
247
+ journal = {ICCV},
248
+ year = {2021},
249
+ }
250
+ ```
251
+
252
+ ### Acknowledgements
253
+
254
+ Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT).
255
+ We'd like to thank the authors for making these libraries available.
256
+
257
+ ### License
258
+
259
+ MIT License
custom_midas_repo/__init__.py ADDED
File without changes
custom_midas_repo/hubconf.py ADDED
@@ -0,0 +1,435 @@
1
+ dependencies = ["torch"]
2
+
3
+ import torch
4
+
5
+ from custom_midas_repo.midas.dpt_depth import DPTDepthModel
6
+ from custom_midas_repo.midas.midas_net import MidasNet
7
+ from custom_midas_repo.midas.midas_net_custom import MidasNet_small
8
+
9
+ def DPT_BEiT_L_512(pretrained=True, **kwargs):
10
+ """ # This docstring shows up in hub.help()
11
+ MiDaS DPT_BEiT_L_512 model for monocular depth estimation
12
+ pretrained (bool): load pretrained weights into model
13
+ """
14
+
15
+ model = DPTDepthModel(
16
+ path=None,
17
+ backbone="beitl16_512",
18
+ non_negative=True,
19
+ )
20
+
21
+ if pretrained:
22
+ checkpoint = (
23
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt"
24
+ )
25
+ state_dict = torch.hub.load_state_dict_from_url(
26
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
27
+ )
28
+ model.load_state_dict(state_dict)
29
+
30
+ return model
31
+
32
+ def DPT_BEiT_L_384(pretrained=True, **kwargs):
33
+ """ # This docstring shows up in hub.help()
34
+ MiDaS DPT_BEiT_L_384 model for monocular depth estimation
35
+ pretrained (bool): load pretrained weights into model
36
+ """
37
+
38
+ model = DPTDepthModel(
39
+ path=None,
40
+ backbone="beitl16_384",
41
+ non_negative=True,
42
+ )
43
+
44
+ if pretrained:
45
+ checkpoint = (
46
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt"
47
+ )
48
+ state_dict = torch.hub.load_state_dict_from_url(
49
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
50
+ )
51
+ model.load_state_dict(state_dict)
52
+
53
+ return model
54
+
55
+ def DPT_BEiT_B_384(pretrained=True, **kwargs):
56
+ """ # This docstring shows up in hub.help()
57
+ MiDaS DPT_BEiT_B_384 model for monocular depth estimation
58
+ pretrained (bool): load pretrained weights into model
59
+ """
60
+
61
+ model = DPTDepthModel(
62
+ path=None,
63
+ backbone="beitb16_384",
64
+ non_negative=True,
65
+ )
66
+
67
+ if pretrained:
68
+ checkpoint = (
69
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt"
70
+ )
71
+ state_dict = torch.hub.load_state_dict_from_url(
72
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
73
+ )
74
+ model.load_state_dict(state_dict)
75
+
76
+ return model
77
+
78
+ def DPT_SwinV2_L_384(pretrained=True, **kwargs):
79
+ """ # This docstring shows up in hub.help()
80
+ MiDaS DPT_SwinV2_L_384 model for monocular depth estimation
81
+ pretrained (bool): load pretrained weights into model
82
+ """
83
+
84
+ model = DPTDepthModel(
85
+ path=None,
86
+ backbone="swin2l24_384",
87
+ non_negative=True,
88
+ )
89
+
90
+ if pretrained:
91
+ checkpoint = (
92
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt"
93
+ )
94
+ state_dict = torch.hub.load_state_dict_from_url(
95
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
96
+ )
97
+ model.load_state_dict(state_dict)
98
+
99
+ return model
100
+
101
+ def DPT_SwinV2_B_384(pretrained=True, **kwargs):
102
+ """ # This docstring shows up in hub.help()
103
+ MiDaS DPT_SwinV2_B_384 model for monocular depth estimation
104
+ pretrained (bool): load pretrained weights into model
105
+ """
106
+
107
+ model = DPTDepthModel(
108
+ path=None,
109
+ backbone="swin2b24_384",
110
+ non_negative=True,
111
+ )
112
+
113
+ if pretrained:
114
+ checkpoint = (
115
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt"
116
+ )
117
+ state_dict = torch.hub.load_state_dict_from_url(
118
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
119
+ )
120
+ model.load_state_dict(state_dict)
121
+
122
+ return model
123
+
124
+ def DPT_SwinV2_T_256(pretrained=True, **kwargs):
125
+ """ # This docstring shows up in hub.help()
126
+ MiDaS DPT_SwinV2_T_256 model for monocular depth estimation
127
+ pretrained (bool): load pretrained weights into model
128
+ """
129
+
130
+ model = DPTDepthModel(
131
+ path=None,
132
+ backbone="swin2t16_256",
133
+ non_negative=True,
134
+ )
135
+
136
+ if pretrained:
137
+ checkpoint = (
138
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt"
139
+ )
140
+ state_dict = torch.hub.load_state_dict_from_url(
141
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
142
+ )
143
+ model.load_state_dict(state_dict)
144
+
145
+ return model
146
+
147
+ def DPT_Swin_L_384(pretrained=True, **kwargs):
148
+ """ # This docstring shows up in hub.help()
149
+ MiDaS DPT_Swin_L_384 model for monocular depth estimation
150
+ pretrained (bool): load pretrained weights into model
151
+ """
152
+
153
+ model = DPTDepthModel(
154
+ path=None,
155
+ backbone="swinl12_384",
156
+ non_negative=True,
157
+ )
158
+
159
+ if pretrained:
160
+ checkpoint = (
161
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt"
162
+ )
163
+ state_dict = torch.hub.load_state_dict_from_url(
164
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
165
+ )
166
+ model.load_state_dict(state_dict)
167
+
168
+ return model
169
+
170
+ def DPT_Next_ViT_L_384(pretrained=True, **kwargs):
171
+ """ # This docstring shows up in hub.help()
172
+ MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation
173
+ pretrained (bool): load pretrained weights into model
174
+ """
175
+
176
+ model = DPTDepthModel(
177
+ path=None,
178
+ backbone="next_vit_large_6m",
179
+ non_negative=True,
180
+ )
181
+
182
+ if pretrained:
183
+ checkpoint = (
184
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt"
185
+ )
186
+ state_dict = torch.hub.load_state_dict_from_url(
187
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
188
+ )
189
+ model.load_state_dict(state_dict)
190
+
191
+ return model
192
+
193
+ def DPT_LeViT_224(pretrained=True, **kwargs):
194
+ """ # This docstring shows up in hub.help()
195
+ MiDaS DPT_LeViT_224 model for monocular depth estimation
196
+ pretrained (bool): load pretrained weights into model
197
+ """
198
+
199
+ model = DPTDepthModel(
200
+ path=None,
201
+ backbone="levit_384",
202
+ non_negative=True,
203
+ head_features_1=64,
204
+ head_features_2=8,
205
+ )
206
+
207
+ if pretrained:
208
+ checkpoint = (
209
+ "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt"
210
+ )
211
+ state_dict = torch.hub.load_state_dict_from_url(
212
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
213
+ )
214
+ model.load_state_dict(state_dict)
215
+
216
+ return model
217
+
218
+ def DPT_Large(pretrained=True, **kwargs):
219
+ """ # This docstring shows up in hub.help()
220
+ MiDaS DPT-Large model for monocular depth estimation
221
+ pretrained (bool): load pretrained weights into model
222
+ """
223
+
224
+ model = DPTDepthModel(
225
+ path=None,
226
+ backbone="vitl16_384",
227
+ non_negative=True,
228
+ )
229
+
230
+ if pretrained:
231
+ checkpoint = (
232
+ "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt"
233
+ )
234
+ state_dict = torch.hub.load_state_dict_from_url(
235
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
236
+ )
237
+ model.load_state_dict(state_dict)
238
+
239
+ return model
240
+
241
+ def DPT_Hybrid(pretrained=True, **kwargs):
242
+ """ # This docstring shows up in hub.help()
243
+ MiDaS DPT-Hybrid model for monocular depth estimation
244
+ pretrained (bool): load pretrained weights into model
245
+ """
246
+
247
+ model = DPTDepthModel(
248
+ path=None,
249
+ backbone="vitb_rn50_384",
250
+ non_negative=True,
251
+ )
252
+
253
+ if pretrained:
254
+ checkpoint = (
255
+ "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt"
256
+ )
257
+ state_dict = torch.hub.load_state_dict_from_url(
258
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
259
+ )
260
+ model.load_state_dict(state_dict)
261
+
262
+ return model
263
+
264
+ def MiDaS(pretrained=True, **kwargs):
265
+ """ # This docstring shows up in hub.help()
266
+ MiDaS v2.1 model for monocular depth estimation
267
+ pretrained (bool): load pretrained weights into model
268
+ """
269
+
270
+ model = MidasNet()
271
+
272
+ if pretrained:
273
+ checkpoint = (
274
+ "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt"
275
+ )
276
+ state_dict = torch.hub.load_state_dict_from_url(
277
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
278
+ )
279
+ model.load_state_dict(state_dict)
280
+
281
+ return model
282
+
283
+ def MiDaS_small(pretrained=True, **kwargs):
284
+ """ # This docstring shows up in hub.help()
285
+ MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices
286
+ pretrained (bool): load pretrained weights into model
287
+ """
288
+
289
+ model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True})
290
+
291
+ if pretrained:
292
+ checkpoint = (
293
+ "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
294
+ )
295
+ state_dict = torch.hub.load_state_dict_from_url(
296
+ checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
297
+ )
298
+ model.load_state_dict(state_dict)
299
+
300
+ return model
301
+
302
+
303
+ def transforms():
304
+ import cv2
305
+ from torchvision.transforms import Compose
306
+ from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet
307
+ from custom_midas_repo.midas import transforms
308
+
309
+ transforms.default_transform = Compose(
310
+ [
311
+ lambda img: {"image": img / 255.0},
312
+ Resize(
313
+ 384,
314
+ 384,
315
+ resize_target=None,
316
+ keep_aspect_ratio=True,
317
+ ensure_multiple_of=32,
318
+ resize_method="upper_bound",
319
+ image_interpolation_method=cv2.INTER_CUBIC,
320
+ ),
321
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
322
+ PrepareForNet(),
323
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
324
+ ]
325
+ )
326
+
327
+ transforms.small_transform = Compose(
328
+ [
329
+ lambda img: {"image": img / 255.0},
330
+ Resize(
331
+ 256,
332
+ 256,
333
+ resize_target=None,
334
+ keep_aspect_ratio=True,
335
+ ensure_multiple_of=32,
336
+ resize_method="upper_bound",
337
+ image_interpolation_method=cv2.INTER_CUBIC,
338
+ ),
339
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
340
+ PrepareForNet(),
341
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
342
+ ]
343
+ )
344
+
345
+ transforms.dpt_transform = Compose(
346
+ [
347
+ lambda img: {"image": img / 255.0},
348
+ Resize(
349
+ 384,
350
+ 384,
351
+ resize_target=None,
352
+ keep_aspect_ratio=True,
353
+ ensure_multiple_of=32,
354
+ resize_method="minimal",
355
+ image_interpolation_method=cv2.INTER_CUBIC,
356
+ ),
357
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
358
+ PrepareForNet(),
359
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
360
+ ]
361
+ )
362
+
363
+ transforms.beit512_transform = Compose(
364
+ [
365
+ lambda img: {"image": img / 255.0},
366
+ Resize(
367
+ 512,
368
+ 512,
369
+ resize_target=None,
370
+ keep_aspect_ratio=True,
371
+ ensure_multiple_of=32,
372
+ resize_method="minimal",
373
+ image_interpolation_method=cv2.INTER_CUBIC,
374
+ ),
375
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
376
+ PrepareForNet(),
377
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
378
+ ]
379
+ )
380
+
381
+ transforms.swin384_transform = Compose(
382
+ [
383
+ lambda img: {"image": img / 255.0},
384
+ Resize(
385
+ 384,
386
+ 384,
387
+ resize_target=None,
388
+ keep_aspect_ratio=False,
389
+ ensure_multiple_of=32,
390
+ resize_method="minimal",
391
+ image_interpolation_method=cv2.INTER_CUBIC,
392
+ ),
393
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
394
+ PrepareForNet(),
395
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
396
+ ]
397
+ )
398
+
399
+ transforms.swin256_transform = Compose(
400
+ [
401
+ lambda img: {"image": img / 255.0},
402
+ Resize(
403
+ 256,
404
+ 256,
405
+ resize_target=None,
406
+ keep_aspect_ratio=False,
407
+ ensure_multiple_of=32,
408
+ resize_method="minimal",
409
+ image_interpolation_method=cv2.INTER_CUBIC,
410
+ ),
411
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
412
+ PrepareForNet(),
413
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
414
+ ]
415
+ )
416
+
417
+ transforms.levit_transform = Compose(
418
+ [
419
+ lambda img: {"image": img / 255.0},
420
+ Resize(
421
+ 224,
422
+ 224,
423
+ resize_target=None,
424
+ keep_aspect_ratio=False,
425
+ ensure_multiple_of=32,
426
+ resize_method="minimal",
427
+ image_interpolation_method=cv2.INTER_CUBIC,
428
+ ),
429
+ NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
430
+ PrepareForNet(),
431
+ lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
432
+ ]
433
+ )
434
+
435
+ return transforms
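The entrypoints above can also be used without `torch.hub`. A rough end-to-end sketch, assuming an image exists at the placeholder path `input/example.jpg`; the final bicubic upsampling is common post-processing, not part of `hubconf.py`:

```python
# Rough inference sketch using the hubconf entrypoints defined above.
import cv2
import torch

from custom_midas_repo import hubconf

model = hubconf.DPT_BEiT_L_512(pretrained=True)
model.eval()

transform = hubconf.transforms().beit512_transform  # transform matching the model

img = cv2.cvtColor(cv2.imread("input/example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path
input_batch = transform(img)

with torch.no_grad():
    prediction = model(input_batch)                   # (1, H', W') inverse depth
    prediction = torch.nn.functional.interpolate(     # upsample back to the input size
        prediction.unsqueeze(1),
        size=img.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze(1)

depth = prediction[0].cpu().numpy()
```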
custom_midas_repo/midas/__init__.py ADDED
File without changes
custom_midas_repo/midas/backbones/__init__.py ADDED
File without changes
custom_midas_repo/midas/backbones/beit.py ADDED
@@ -0,0 +1,196 @@
1
+ import custom_timm as timm
2
+ import torch
3
+ import types
4
+
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+
8
+ from .utils import forward_adapted_unflatten, make_backbone_default
9
+ from custom_timm.models.beit import gen_relative_position_index
10
+ from torch.utils.checkpoint import checkpoint
11
+ from typing import Optional
12
+
13
+
14
+ def forward_beit(pretrained, x):
15
+ return forward_adapted_unflatten(pretrained, x, "forward_features")
16
+
17
+
18
+ def patch_embed_forward(self, x):
19
+ """
20
+ Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
21
+ """
22
+ x = self.proj(x)
23
+ if self.flatten:
24
+ x = x.flatten(2).transpose(1, 2)
25
+ x = self.norm(x)
26
+ return x
27
+
28
+
29
+ def _get_rel_pos_bias(self, window_size):
30
+ """
31
+ Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
32
+ """
33
+ old_height = 2 * self.window_size[0] - 1
34
+ old_width = 2 * self.window_size[1] - 1
35
+
36
+ new_height = 2 * window_size[0] - 1
37
+ new_width = 2 * window_size[1] - 1
38
+
39
+ old_relative_position_bias_table = self.relative_position_bias_table
40
+
41
+ old_num_relative_distance = self.num_relative_distance
42
+ new_num_relative_distance = new_height * new_width + 3
43
+
44
+ old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]
45
+
46
+ old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
47
+ new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear")
48
+ new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)
49
+
50
+ new_relative_position_bias_table = torch.cat(
51
+ [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])
52
+
53
+ key = str(window_size[1]) + "," + str(window_size[0])
54
+ if key not in self.relative_position_indices.keys():
55
+ self.relative_position_indices[key] = gen_relative_position_index(window_size)
56
+
57
+ relative_position_bias = new_relative_position_bias_table[
58
+ self.relative_position_indices[key].view(-1)].view(
59
+ window_size[0] * window_size[1] + 1,
60
+ window_size[0] * window_size[1] + 1, -1) # Wh*Ww,Wh*Ww,nH
61
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
62
+ return relative_position_bias.unsqueeze(0)
63
+
64
+
65
+ def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
66
+ """
67
+ Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
68
+ """
69
+ B, N, C = x.shape
70
+
71
+ qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
72
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
73
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
74
+ q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
75
+
76
+ q = q * self.scale
77
+ attn = (q @ k.transpose(-2, -1))
78
+
79
+ if self.relative_position_bias_table is not None:
80
+ window_size = tuple(np.array(resolution) // 16)
81
+ attn = attn + self._get_rel_pos_bias(window_size)
82
+ if shared_rel_pos_bias is not None:
83
+ attn = attn + shared_rel_pos_bias
84
+
85
+ attn = attn.softmax(dim=-1)
86
+ attn = self.attn_drop(attn)
87
+
88
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
89
+ x = self.proj(x)
90
+ x = self.proj_drop(x)
91
+ return x
92
+
93
+
94
+ def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
95
+ """
96
+ Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
97
+ """
98
+ if self.gamma_1 is None:
99
+ x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
100
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
101
+ else:
102
+ x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
103
+ shared_rel_pos_bias=shared_rel_pos_bias))
104
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
105
+ return x
106
+
107
+
108
+ def beit_forward_features(self, x):
109
+ """
110
+ Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
111
+ """
112
+ resolution = x.shape[2:]
113
+
114
+ x = self.patch_embed(x)
115
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
116
+ if self.pos_embed is not None:
117
+ x = x + self.pos_embed
118
+ x = self.pos_drop(x)
119
+
120
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
121
+ for blk in self.blocks:
122
+ if self.grad_checkpointing and not torch.jit.is_scripting():
123
+ x = checkpoint(blk, x, resolution, shared_rel_pos_bias=rel_pos_bias)
124
+ else:
125
+ x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
126
+ x = self.norm(x)
127
+ return x
128
+
129
+
130
+ def _make_beit_backbone(
131
+ model,
132
+ features=[96, 192, 384, 768],
133
+ size=[384, 384],
134
+ hooks=[0, 4, 8, 11],
135
+ vit_features=768,
136
+ use_readout="ignore",
137
+ start_index=1,
138
+ start_index_readout=1,
139
+ ):
140
+ backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
141
+ start_index_readout)
142
+
143
+ backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
144
+ backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)
145
+
146
+ for block in backbone.model.blocks:
147
+ attn = block.attn
148
+ attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
149
+ attn.forward = types.MethodType(attention_forward, attn)
150
+ attn.relative_position_indices = {}
151
+
152
+ block.forward = types.MethodType(block_forward, block)
153
+
154
+ return backbone
155
+
156
+
157
+ def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
158
+ model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)
159
+
160
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
161
+
162
+ features = [256, 512, 1024, 1024]
163
+
164
+ return _make_beit_backbone(
165
+ model,
166
+ features=features,
167
+ size=[512, 512],
168
+ hooks=hooks,
169
+ vit_features=1024,
170
+ use_readout=use_readout,
171
+ )
172
+
173
+
174
+ def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
175
+ model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)
176
+
177
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
178
+ return _make_beit_backbone(
179
+ model,
180
+ features=[256, 512, 1024, 1024],
181
+ hooks=hooks,
182
+ vit_features=1024,
183
+ use_readout=use_readout,
184
+ )
185
+
186
+
187
+ def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
188
+ model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)
189
+
190
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
191
+ return _make_beit_backbone(
192
+ model,
193
+ features=[96, 192, 384, 768],
194
+ hooks=hooks,
195
+ use_readout=use_readout,
196
+ )
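The key step of the window-size adaptation in `_get_rel_pos_bias` above is a bilinear interpolation of the learned relative-position-bias table; a self-contained sketch of just that step, with illustrative (not model-specific) sizes:

```python
# Standalone illustration of the bias-table interpolation performed in _get_rel_pos_bias.
import torch
import torch.nn.functional as F

num_heads = 16
old_h = old_w = 2 * 24 - 1          # 2*W - 1 for the training window (illustrative)
new_h = new_w = 2 * 32 - 1          # target window (illustrative)

old_table = torch.randn(old_h * old_w, num_heads)                 # table without cls entries

t = old_table.reshape(1, old_w, old_h, -1).permute(0, 3, 1, 2)    # (1, nH, old_w, old_h)
t = F.interpolate(t, size=(new_h, new_w), mode="bilinear")        # (1, nH, new_h, new_w)
new_table = t.permute(0, 2, 3, 1).reshape(new_h * new_w, -1)      # (new_h*new_w, nH)

assert new_table.shape == (new_h * new_w, num_heads)
```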
custom_midas_repo/midas/backbones/levit.py ADDED
@@ -0,0 +1,106 @@
1
+ import custom_timm as timm
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from .utils import activations, get_activation, Transpose
7
+
8
+
9
+ def forward_levit(pretrained, x):
10
+ pretrained.model.forward_features(x)
11
+
12
+ layer_1 = pretrained.activations["1"]
13
+ layer_2 = pretrained.activations["2"]
14
+ layer_3 = pretrained.activations["3"]
15
+
16
+ layer_1 = pretrained.act_postprocess1(layer_1)
17
+ layer_2 = pretrained.act_postprocess2(layer_2)
18
+ layer_3 = pretrained.act_postprocess3(layer_3)
19
+
20
+ return layer_1, layer_2, layer_3
21
+
22
+
23
+ def _make_levit_backbone(
24
+ model,
25
+ hooks=[3, 11, 21],
26
+ patch_grid=[14, 14]
27
+ ):
28
+ pretrained = nn.Module()
29
+
30
+ pretrained.model = model
31
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
32
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
33
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
34
+
35
+ pretrained.activations = activations
36
+
37
+ patch_grid_size = np.array(patch_grid, dtype=int)
38
+
39
+ pretrained.act_postprocess1 = nn.Sequential(
40
+ Transpose(1, 2),
41
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
42
+ )
43
+ pretrained.act_postprocess2 = nn.Sequential(
44
+ Transpose(1, 2),
45
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
46
+ )
47
+ pretrained.act_postprocess3 = nn.Sequential(
48
+ Transpose(1, 2),
49
+ nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
50
+ )
51
+
52
+ return pretrained
53
+
54
+
55
+ class ConvTransposeNorm(nn.Sequential):
56
+ """
57
+ Modification of
58
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
59
+ such that ConvTranspose2d is used instead of Conv2d.
60
+ """
61
+
62
+ def __init__(
63
+ self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
64
+ groups=1, bn_weight_init=1):
65
+ super().__init__()
66
+ self.add_module('c',
67
+ nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
68
+ self.add_module('bn', nn.BatchNorm2d(out_chs))
69
+
70
+ nn.init.constant_(self.bn.weight, bn_weight_init)
71
+
72
+ @torch.no_grad()
73
+ def fuse(self):
74
+ c, bn = self._modules.values()
75
+ w = bn.weight / (bn.running_var + bn.eps) ** 0.5
76
+ w = c.weight * w[:, None, None, None]
77
+ b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
78
+ m = nn.ConvTranspose2d(
79
+ w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
80
+ padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
81
+ m.weight.data.copy_(w)
82
+ m.bias.data.copy_(b)
83
+ return m
84
+
85
+
86
+ def stem_b4_transpose(in_chs, out_chs, activation):
87
+ """
88
+ Modification of
89
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
90
+ such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
91
+ """
92
+ return nn.Sequential(
93
+ ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
94
+ activation(),
95
+ ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
96
+ activation())
97
+
98
+
99
+ def _make_pretrained_levit_384(pretrained, hooks=None):
100
+ model = timm.create_model("levit_384", pretrained=pretrained)
101
+
102
+ hooks = [3, 11, 21] if hooks is None else hooks
103
+ return _make_levit_backbone(
104
+ model,
105
+ hooks=hooks
106
+ )
custom_midas_repo/midas/backbones/next_vit.py ADDED
@@ -0,0 +1,39 @@
1
+ import custom_timm as timm
2
+
3
+ import torch.nn as nn
4
+
5
+ from pathlib import Path
6
+ from .utils import activations, forward_default, get_activation
7
+
8
+ from ..external.next_vit.classification.nextvit import *
9
+
10
+
11
+ def forward_next_vit(pretrained, x):
12
+ return forward_default(pretrained, x, "forward")
13
+
14
+
15
+ def _make_next_vit_backbone(
16
+ model,
17
+ hooks=[2, 6, 36, 39],
18
+ ):
19
+ pretrained = nn.Module()
20
+
21
+ pretrained.model = model
22
+ pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
23
+ pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
24
+ pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
25
+ pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))
26
+
27
+ pretrained.activations = activations
28
+
29
+ return pretrained
30
+
31
+
32
+ def _make_pretrained_next_vit_large_6m(hooks=None):
33
+ model = timm.create_model("nextvit_large")
34
+
35
+ hooks = [2, 6, 36, 39] if hooks is None else hooks
36
+ return _make_next_vit_backbone(
37
+ model,
38
+ hooks=hooks,
39
+ )
custom_midas_repo/midas/backbones/swin.py ADDED
@@ -0,0 +1,13 @@
1
+ import custom_timm as timm
2
+
3
+ from .swin_common import _make_swin_backbone
4
+
5
+
6
+ def _make_pretrained_swinl12_384(pretrained, hooks=None):
7
+ model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)
8
+
9
+ hooks = [1, 1, 17, 1] if hooks is None else hooks
10
+ return _make_swin_backbone(
11
+ model,
12
+ hooks=hooks
13
+ )
custom_midas_repo/midas/backbones/swin2.py ADDED
@@ -0,0 +1,34 @@
1
+ import custom_timm as timm
2
+
3
+ from .swin_common import _make_swin_backbone
4
+
5
+
6
+ def _make_pretrained_swin2l24_384(pretrained, hooks=None):
7
+ model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)
8
+
9
+ hooks = [1, 1, 17, 1] if hooks is None else hooks
10
+ return _make_swin_backbone(
11
+ model,
12
+ hooks=hooks
13
+ )
14
+
15
+
16
+ def _make_pretrained_swin2b24_384(pretrained, hooks=None):
17
+ model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)
18
+
19
+ hooks = [1, 1, 17, 1] if hooks is None else hooks
20
+ return _make_swin_backbone(
21
+ model,
22
+ hooks=hooks
23
+ )
24
+
25
+
26
+ def _make_pretrained_swin2t16_256(pretrained, hooks=None):
27
+ model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)
28
+
29
+ hooks = [1, 1, 5, 1] if hooks is None else hooks
30
+ return _make_swin_backbone(
31
+ model,
32
+ hooks=hooks,
33
+ patch_grid=[64, 64]
34
+ )
custom_midas_repo/midas/backbones/swin_common.py ADDED
@@ -0,0 +1,52 @@
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+ import numpy as np
5
+
6
+ from .utils import activations, forward_default, get_activation, Transpose
7
+
8
+
9
+ def forward_swin(pretrained, x):
10
+ return forward_default(pretrained, x)
11
+
12
+
13
+ def _make_swin_backbone(
14
+ model,
15
+ hooks=[1, 1, 17, 1],
16
+ patch_grid=[96, 96]
17
+ ):
18
+ pretrained = nn.Module()
19
+
20
+ pretrained.model = model
21
+ pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
22
+ pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
23
+ pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
24
+ pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))
25
+
26
+ pretrained.activations = activations
27
+
28
+ if hasattr(model, "patch_grid"):
29
+ used_patch_grid = model.patch_grid
30
+ else:
31
+ used_patch_grid = patch_grid
32
+
33
+ patch_grid_size = np.array(used_patch_grid, dtype=int)
34
+
35
+ pretrained.act_postprocess1 = nn.Sequential(
36
+ Transpose(1, 2),
37
+ nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
38
+ )
39
+ pretrained.act_postprocess2 = nn.Sequential(
40
+ Transpose(1, 2),
41
+ nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
42
+ )
43
+ pretrained.act_postprocess3 = nn.Sequential(
44
+ Transpose(1, 2),
45
+ nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
46
+ )
47
+ pretrained.act_postprocess4 = nn.Sequential(
48
+ Transpose(1, 2),
49
+ nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
50
+ )
51
+
52
+ return pretrained
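The `act_postprocess*` modules above only reshape the hooked activations from token form back into 2D feature maps; a small shape check of that `Transpose` + `Unflatten` step, with illustrative sizes:

```python
# Shape check for the Transpose(1, 2) + nn.Unflatten(2, (H, W)) post-processing above.
import torch
import torch.nn as nn

from custom_midas_repo.midas.backbones.utils import Transpose

batch, channels, grid = 2, 192, (48, 48)              # illustrative sizes; 48*48 = 2304 tokens
tokens = torch.randn(batch, grid[0] * grid[1], channels)

postprocess = nn.Sequential(
    Transpose(1, 2),                                  # (B, N, C) -> (B, C, N)
    nn.Unflatten(2, torch.Size(grid)),                # (B, C, N) -> (B, C, H, W)
)

assert postprocess(tokens).shape == (batch, channels, grid[0], grid[1])
```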
custom_midas_repo/midas/backbones/utils.py ADDED
@@ -0,0 +1,249 @@
1
+ import torch
2
+
3
+ import torch.nn as nn
4
+
5
+
6
+ class Slice(nn.Module):
7
+ def __init__(self, start_index=1):
8
+ super(Slice, self).__init__()
9
+ self.start_index = start_index
10
+
11
+ def forward(self, x):
12
+ return x[:, self.start_index:]
13
+
14
+
15
+ class AddReadout(nn.Module):
16
+ def __init__(self, start_index=1):
17
+ super(AddReadout, self).__init__()
18
+ self.start_index = start_index
19
+
20
+ def forward(self, x):
21
+ if self.start_index == 2:
22
+ readout = (x[:, 0] + x[:, 1]) / 2
23
+ else:
24
+ readout = x[:, 0]
25
+ return x[:, self.start_index:] + readout.unsqueeze(1)
26
+
27
+
28
+ class ProjectReadout(nn.Module):
29
+ def __init__(self, in_features, start_index=1):
30
+ super(ProjectReadout, self).__init__()
31
+ self.start_index = start_index
32
+
33
+ self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())
34
+
35
+ def forward(self, x):
36
+ readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
37
+ features = torch.cat((x[:, self.start_index:], readout), -1)
38
+
39
+ return self.project(features)
40
+
41
+
42
+ class Transpose(nn.Module):
43
+ def __init__(self, dim0, dim1):
44
+ super(Transpose, self).__init__()
45
+ self.dim0 = dim0
46
+ self.dim1 = dim1
47
+
48
+ def forward(self, x):
49
+ x = x.transpose(self.dim0, self.dim1)
50
+ return x
51
+
52
+
53
+ activations = {}
54
+
55
+
56
+ def get_activation(name):
57
+ def hook(model, input, output):
58
+ activations[name] = output
59
+
60
+ return hook
61
+
62
+
63
+ def forward_default(pretrained, x, function_name="forward_features"):
64
+ getattr(pretrained.model, function_name)(x)
65
+
66
+ layer_1 = pretrained.activations["1"]
67
+ layer_2 = pretrained.activations["2"]
68
+ layer_3 = pretrained.activations["3"]
69
+ layer_4 = pretrained.activations["4"]
70
+
71
+ if hasattr(pretrained, "act_postprocess1"):
72
+ layer_1 = pretrained.act_postprocess1(layer_1)
73
+ if hasattr(pretrained, "act_postprocess2"):
74
+ layer_2 = pretrained.act_postprocess2(layer_2)
75
+ if hasattr(pretrained, "act_postprocess3"):
76
+ layer_3 = pretrained.act_postprocess3(layer_3)
77
+ if hasattr(pretrained, "act_postprocess4"):
78
+ layer_4 = pretrained.act_postprocess4(layer_4)
79
+
80
+ return layer_1, layer_2, layer_3, layer_4
81
+
82
+
83
+ def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
84
+ b, c, h, w = x.shape
85
+
86
+ getattr(pretrained.model, function_name)(x)
87
+
88
+ layer_1 = pretrained.activations["1"]
89
+ layer_2 = pretrained.activations["2"]
90
+ layer_3 = pretrained.activations["3"]
91
+ layer_4 = pretrained.activations["4"]
92
+
93
+ layer_1 = pretrained.act_postprocess1[0:2](layer_1)
94
+ layer_2 = pretrained.act_postprocess2[0:2](layer_2)
95
+ layer_3 = pretrained.act_postprocess3[0:2](layer_3)
96
+ layer_4 = pretrained.act_postprocess4[0:2](layer_4)
97
+
98
+ unflatten = nn.Sequential(
99
+ nn.Unflatten(
100
+ 2,
101
+ torch.Size(
102
+ [
103
+ h // pretrained.model.patch_size[1],
104
+ w // pretrained.model.patch_size[0],
105
+ ]
106
+ ),
107
+ )
108
+ )
109
+
110
+ if layer_1.ndim == 3:
111
+ layer_1 = unflatten(layer_1)
112
+ if layer_2.ndim == 3:
113
+ layer_2 = unflatten(layer_2)
114
+ if layer_3.ndim == 3:
115
+ layer_3 = unflatten(layer_3)
116
+ if layer_4.ndim == 3:
117
+ layer_4 = unflatten(layer_4)
118
+
119
+ layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
120
+ layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
121
+ layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
122
+ layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)
123
+
124
+ return layer_1, layer_2, layer_3, layer_4
125
+
126
+
127
+ def get_readout_oper(vit_features, features, use_readout, start_index=1):
128
+ if use_readout == "ignore":
129
+ readout_oper = [Slice(start_index)] * len(features)
130
+ elif use_readout == "add":
131
+ readout_oper = [AddReadout(start_index)] * len(features)
132
+ elif use_readout == "project":
133
+ readout_oper = [
134
+ ProjectReadout(vit_features, start_index) for out_feat in features
135
+ ]
136
+ else:
137
+ raise ValueError("wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'")
140
+
141
+ return readout_oper
142
+
143
+
144
+ def make_backbone_default(
145
+ model,
146
+ features=[96, 192, 384, 768],
147
+ size=[384, 384],
148
+ hooks=[2, 5, 8, 11],
149
+ vit_features=768,
150
+ use_readout="ignore",
151
+ start_index=1,
152
+ start_index_readout=1,
153
+ ):
154
+ pretrained = nn.Module()
155
+
156
+ pretrained.model = model
157
+ pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
158
+ pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
159
+ pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
160
+ pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))
161
+
162
+ pretrained.activations = activations
163
+
164
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)
165
+
166
+ # 32, 48, 136, 384
167
+ pretrained.act_postprocess1 = nn.Sequential(
168
+ readout_oper[0],
169
+ Transpose(1, 2),
170
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
171
+ nn.Conv2d(
172
+ in_channels=vit_features,
173
+ out_channels=features[0],
174
+ kernel_size=1,
175
+ stride=1,
176
+ padding=0,
177
+ ),
178
+ nn.ConvTranspose2d(
179
+ in_channels=features[0],
180
+ out_channels=features[0],
181
+ kernel_size=4,
182
+ stride=4,
183
+ padding=0,
184
+ bias=True,
185
+ dilation=1,
186
+ groups=1,
187
+ ),
188
+ )
189
+
190
+ pretrained.act_postprocess2 = nn.Sequential(
191
+ readout_oper[1],
192
+ Transpose(1, 2),
193
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
194
+ nn.Conv2d(
195
+ in_channels=vit_features,
196
+ out_channels=features[1],
197
+ kernel_size=1,
198
+ stride=1,
199
+ padding=0,
200
+ ),
201
+ nn.ConvTranspose2d(
202
+ in_channels=features[1],
203
+ out_channels=features[1],
204
+ kernel_size=2,
205
+ stride=2,
206
+ padding=0,
207
+ bias=True,
208
+ dilation=1,
209
+ groups=1,
210
+ ),
211
+ )
212
+
213
+ pretrained.act_postprocess3 = nn.Sequential(
214
+ readout_oper[2],
215
+ Transpose(1, 2),
216
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
217
+ nn.Conv2d(
218
+ in_channels=vit_features,
219
+ out_channels=features[2],
220
+ kernel_size=1,
221
+ stride=1,
222
+ padding=0,
223
+ ),
224
+ )
225
+
226
+ pretrained.act_postprocess4 = nn.Sequential(
227
+ readout_oper[3],
228
+ Transpose(1, 2),
229
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
230
+ nn.Conv2d(
231
+ in_channels=vit_features,
232
+ out_channels=features[3],
233
+ kernel_size=1,
234
+ stride=1,
235
+ padding=0,
236
+ ),
237
+ nn.Conv2d(
238
+ in_channels=features[3],
239
+ out_channels=features[3],
240
+ kernel_size=3,
241
+ stride=2,
242
+ padding=1,
243
+ ),
244
+ )
245
+
246
+ pretrained.model.start_index = start_index
247
+ pretrained.model.patch_size = [16, 16]
248
+
249
+ return pretrained
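
get_readout_oper above selects one of three ways of folding the ViT readout (cls) token back into the patch tokens before they are reshaped into feature maps. A minimal sketch of what the three modes do to a token tensor, assuming the Slice, AddReadout and ProjectReadout classes defined earlier in this file follow the upstream MiDaS definitions:

import torch
from custom_midas_repo.midas.backbones.utils import Slice, AddReadout, ProjectReadout

tokens = torch.randn(2, 1 + 24 * 24, 768)      # [batch, cls token + 24x24 patch tokens, vit_features]

print(Slice(1)(tokens).shape)                  # torch.Size([2, 576, 768]) - cls token simply dropped
print(AddReadout(1)(tokens).shape)             # torch.Size([2, 576, 768]) - cls token added to every patch token
print(ProjectReadout(768, 1)(tokens).shape)    # torch.Size([2, 576, 768]) - cls token concatenated, then projected back to 768

All three keep the patch-token count (576) and the feature width (768), which is what the act_postprocess stages above rely on.
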
custom_midas_repo/midas/backbones/vit.py ADDED
@@ -0,0 +1,221 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import custom_timm as timm
4
+ import types
5
+ import math
6
+ import torch.nn.functional as F
7
+
8
+ from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
9
+ make_backbone_default, Transpose)
10
+
11
+
12
+ def forward_vit(pretrained, x):
13
+ return forward_adapted_unflatten(pretrained, x, "forward_flex")
14
+
15
+
16
+ def _resize_pos_embed(self, posemb, gs_h, gs_w):
17
+ posemb_tok, posemb_grid = (
18
+ posemb[:, : self.start_index],
19
+ posemb[0, self.start_index:],
20
+ )
21
+
22
+ gs_old = int(math.sqrt(len(posemb_grid)))
23
+
24
+ posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
25
+ posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
26
+ posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)
27
+
28
+ posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
29
+
30
+ return posemb
31
+
32
+
33
+ def forward_flex(self, x):
34
+ b, c, h, w = x.shape
35
+
36
+ pos_embed = self._resize_pos_embed(
37
+ self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
38
+ )
39
+
40
+ B = x.shape[0]
41
+
42
+ if hasattr(self.patch_embed, "backbone"):
43
+ x = self.patch_embed.backbone(x)
44
+ if isinstance(x, (list, tuple)):
45
+ x = x[-1] # last feature if backbone outputs list/tuple of features
46
+
47
+ x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)
48
+
49
+ if getattr(self, "dist_token", None) is not None:
50
+ cls_tokens = self.cls_token.expand(
51
+ B, -1, -1
52
+ ) # stole cls_tokens impl from Phil Wang, thanks
53
+ dist_token = self.dist_token.expand(B, -1, -1)
54
+ x = torch.cat((cls_tokens, dist_token, x), dim=1)
55
+ else:
56
+ if self.no_embed_class:
57
+ x = x + pos_embed
58
+ cls_tokens = self.cls_token.expand(
59
+ B, -1, -1
60
+ ) # stole cls_tokens impl from Phil Wang, thanks
61
+ x = torch.cat((cls_tokens, x), dim=1)
62
+
63
+ if not self.no_embed_class:
64
+ x = x + pos_embed
65
+ x = self.pos_drop(x)
66
+
67
+ for blk in self.blocks:
68
+ x = blk(x)
69
+
70
+ x = self.norm(x)
71
+
72
+ return x
73
+
74
+
75
+ def _make_vit_b16_backbone(
76
+ model,
77
+ features=[96, 192, 384, 768],
78
+ size=[384, 384],
79
+ hooks=[2, 5, 8, 11],
80
+ vit_features=768,
81
+ use_readout="ignore",
82
+ start_index=1,
83
+ start_index_readout=1,
84
+ ):
85
+ pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
86
+ start_index_readout)
87
+
88
+ # We inject this function into the VisionTransformer instances so that
89
+ # we can use it with interpolated position embeddings without modifying the library source.
90
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
91
+ pretrained.model._resize_pos_embed = types.MethodType(
92
+ _resize_pos_embed, pretrained.model
93
+ )
94
+
95
+ return pretrained
96
+
97
+
98
+ def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
99
+ model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)
100
+
101
+ hooks = [5, 11, 17, 23] if hooks is None else hooks
102
+ return _make_vit_b16_backbone(
103
+ model,
104
+ features=[256, 512, 1024, 1024],
105
+ hooks=hooks,
106
+ vit_features=1024,
107
+ use_readout=use_readout,
108
+ )
109
+
110
+
111
+ def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
112
+ model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)
113
+
114
+ hooks = [2, 5, 8, 11] if hooks is None else hooks
115
+ return _make_vit_b16_backbone(
116
+ model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
117
+ )
118
+
119
+
120
+ def _make_vit_b_rn50_backbone(
121
+ model,
122
+ features=[256, 512, 768, 768],
123
+ size=[384, 384],
124
+ hooks=[0, 1, 8, 11],
125
+ vit_features=768,
126
+ patch_size=[16, 16],
127
+ number_stages=2,
128
+ use_vit_only=False,
129
+ use_readout="ignore",
130
+ start_index=1,
131
+ ):
132
+ pretrained = nn.Module()
133
+
134
+ pretrained.model = model
135
+
136
+ used_number_stages = 0 if use_vit_only else number_stages
137
+ for s in range(used_number_stages):
138
+ pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
139
+ get_activation(str(s + 1))
140
+ )
141
+ for s in range(used_number_stages, 4):
142
+ pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))
143
+
144
+ pretrained.activations = activations
145
+
146
+ readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)
147
+
148
+ for s in range(used_number_stages):
149
+ value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
150
+ setattr(pretrained, f"act_postprocess{s + 1}", value)
151
+ for s in range(used_number_stages, 4):
152
+ if s < number_stages:
153
+ final_layer = nn.ConvTranspose2d(
154
+ in_channels=features[s],
155
+ out_channels=features[s],
156
+ kernel_size=4 // (2 ** s),
157
+ stride=4 // (2 ** s),
158
+ padding=0,
159
+ bias=True,
160
+ dilation=1,
161
+ groups=1,
162
+ )
163
+ elif s > number_stages:
164
+ final_layer = nn.Conv2d(
165
+ in_channels=features[3],
166
+ out_channels=features[3],
167
+ kernel_size=3,
168
+ stride=2,
169
+ padding=1,
170
+ )
171
+ else:
172
+ final_layer = None
173
+
174
+ layers = [
175
+ readout_oper[s],
176
+ Transpose(1, 2),
177
+ nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
178
+ nn.Conv2d(
179
+ in_channels=vit_features,
180
+ out_channels=features[s],
181
+ kernel_size=1,
182
+ stride=1,
183
+ padding=0,
184
+ ),
185
+ ]
186
+ if final_layer is not None:
187
+ layers.append(final_layer)
188
+
189
+ value = nn.Sequential(*layers)
190
+ setattr(pretrained, f"act_postprocess{s + 1}", value)
191
+
192
+ pretrained.model.start_index = start_index
193
+ pretrained.model.patch_size = patch_size
194
+
195
+ # We inject this function into the VisionTransformer instances so that
196
+ # we can use it with interpolated position embeddings without modifying the library source.
197
+ pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
198
+
199
+ # We inject this function into the VisionTransformer instances so that
200
+ # we can use it with interpolated position embeddings without modifying the library source.
201
+ pretrained.model._resize_pos_embed = types.MethodType(
202
+ _resize_pos_embed, pretrained.model
203
+ )
204
+
205
+ return pretrained
206
+
207
+
208
+ def _make_pretrained_vitb_rn50_384(
209
+ pretrained, use_readout="ignore", hooks=None, use_vit_only=False
210
+ ):
211
+ model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)
212
+
213
+ hooks = [0, 1, 8, 11] if hooks is None else hooks
214
+ return _make_vit_b_rn50_backbone(
215
+ model,
216
+ features=[256, 512, 768, 768],
217
+ size=[384, 384],
218
+ hooks=hooks,
219
+ use_vit_only=use_vit_only,
220
+ use_readout=use_readout,
221
+ )
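
A usage sketch for the ViT-B/16 backbone defined above: build it without downloading weights and run forward_vit on a dummy 384x384 image. This assumes the bundled custom_timm package provides "vit_base_patch16_384" with the same interface as upstream timm; the channel counts come from the features list passed to _make_vit_b16_backbone.

import torch
from custom_midas_repo.midas.backbones.vit import _make_pretrained_vitb16_384, forward_vit

backbone = _make_pretrained_vitb16_384(pretrained=False, use_readout="ignore")
x = torch.randn(1, 3, 384, 384)
with torch.no_grad():
    layer_1, layer_2, layer_3, layer_4 = forward_vit(backbone, x)

# Expected shapes (strides 4/8/16/32 relative to the input):
# layer_1: (1, 96, 96, 96), layer_2: (1, 192, 48, 48), layer_3: (1, 384, 24, 24), layer_4: (1, 768, 12, 12)
print(layer_1.shape, layer_2.shape, layer_3.shape, layer_4.shape)
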
custom_midas_repo/midas/base_model.py ADDED
@@ -0,0 +1,16 @@
1
+ import torch
2
+
3
+
4
+ class BaseModel(torch.nn.Module):
5
+ def load(self, path):
6
+ """Load model from file.
7
+
8
+ Args:
9
+ path (str): file path
10
+ """
11
+ parameters = torch.load(path, map_location=torch.device('cpu'))
12
+
13
+ if "optimizer" in parameters:
14
+ parameters = parameters["model"]
15
+
16
+ self.load_state_dict(parameters)
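
BaseModel.load accepts either a plain state_dict or a training checkpoint that stores the weights under a "model" key next to an "optimizer" entry. A small sketch with a toy subclass (the file names are arbitrary):

import torch
from custom_midas_repo.midas.base_model import BaseModel

class TinyModel(BaseModel):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 2)

m = TinyModel()
torch.save(m.state_dict(), "plain.pt")                                 # plain state_dict
torch.save({"model": m.state_dict(), "optimizer": {}}, "training.pt")  # training-style checkpoint
m.load("plain.pt")
m.load("training.pt")  # "optimizer" key detected, weights are read from parameters["model"]
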
custom_midas_repo/midas/blocks.py ADDED
@@ -0,0 +1,439 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .backbones.beit import (
5
+ _make_pretrained_beitl16_512,
6
+ _make_pretrained_beitl16_384,
7
+ _make_pretrained_beitb16_384,
8
+ forward_beit,
9
+ )
10
+ from .backbones.swin_common import (
11
+ forward_swin,
12
+ )
13
+ from .backbones.swin2 import (
14
+ _make_pretrained_swin2l24_384,
15
+ _make_pretrained_swin2b24_384,
16
+ _make_pretrained_swin2t16_256,
17
+ )
18
+ from .backbones.swin import (
19
+ _make_pretrained_swinl12_384,
20
+ )
21
+ from .backbones.levit import (
22
+ _make_pretrained_levit_384,
23
+ forward_levit,
24
+ )
25
+ from .backbones.vit import (
26
+ _make_pretrained_vitb_rn50_384,
27
+ _make_pretrained_vitl16_384,
28
+ _make_pretrained_vitb16_384,
29
+ forward_vit,
30
+ )
31
+
32
+ def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
33
+ use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
34
+ if backbone == "beitl16_512":
35
+ pretrained = _make_pretrained_beitl16_512(
36
+ use_pretrained, hooks=hooks, use_readout=use_readout
37
+ )
38
+ scratch = _make_scratch(
39
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
40
+ ) # BEiT_512-L (backbone)
41
+ elif backbone == "beitl16_384":
42
+ pretrained = _make_pretrained_beitl16_384(
43
+ use_pretrained, hooks=hooks, use_readout=use_readout
44
+ )
45
+ scratch = _make_scratch(
46
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
47
+ ) # BEiT_384-L (backbone)
48
+ elif backbone == "beitb16_384":
49
+ pretrained = _make_pretrained_beitb16_384(
50
+ use_pretrained, hooks=hooks, use_readout=use_readout
51
+ )
52
+ scratch = _make_scratch(
53
+ [96, 192, 384, 768], features, groups=groups, expand=expand
54
+ ) # BEiT_384-B (backbone)
55
+ elif backbone == "swin2l24_384":
56
+ pretrained = _make_pretrained_swin2l24_384(
57
+ use_pretrained, hooks=hooks
58
+ )
59
+ scratch = _make_scratch(
60
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
61
+ ) # Swin2-L/12to24 (backbone)
62
+ elif backbone == "swin2b24_384":
63
+ pretrained = _make_pretrained_swin2b24_384(
64
+ use_pretrained, hooks=hooks
65
+ )
66
+ scratch = _make_scratch(
67
+ [128, 256, 512, 1024], features, groups=groups, expand=expand
68
+ ) # Swin2-B/12to24 (backbone)
69
+ elif backbone == "swin2t16_256":
70
+ pretrained = _make_pretrained_swin2t16_256(
71
+ use_pretrained, hooks=hooks
72
+ )
73
+ scratch = _make_scratch(
74
+ [96, 192, 384, 768], features, groups=groups, expand=expand
75
+ ) # Swin2-T/16 (backbone)
76
+ elif backbone == "swinl12_384":
77
+ pretrained = _make_pretrained_swinl12_384(
78
+ use_pretrained, hooks=hooks
79
+ )
80
+ scratch = _make_scratch(
81
+ [192, 384, 768, 1536], features, groups=groups, expand=expand
82
+ ) # Swin-L/12 (backbone)
83
+ elif backbone == "next_vit_large_6m":
84
+ from .backbones.next_vit import _make_pretrained_next_vit_large_6m
85
+ pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
86
+ scratch = _make_scratch(
87
+ in_features, features, groups=groups, expand=expand
88
+ ) # Next-ViT-L on ImageNet-1K-6M (backbone)
89
+ elif backbone == "levit_384":
90
+ pretrained = _make_pretrained_levit_384(
91
+ use_pretrained, hooks=hooks
92
+ )
93
+ scratch = _make_scratch(
94
+ [384, 512, 768], features, groups=groups, expand=expand
95
+ ) # LeViT 384 (backbone)
96
+ elif backbone == "vitl16_384":
97
+ pretrained = _make_pretrained_vitl16_384(
98
+ use_pretrained, hooks=hooks, use_readout=use_readout
99
+ )
100
+ scratch = _make_scratch(
101
+ [256, 512, 1024, 1024], features, groups=groups, expand=expand
102
+ ) # ViT-L/16 - 85.0% Top1 (backbone)
103
+ elif backbone == "vitb_rn50_384":
104
+ pretrained = _make_pretrained_vitb_rn50_384(
105
+ use_pretrained,
106
+ hooks=hooks,
107
+ use_vit_only=use_vit_only,
108
+ use_readout=use_readout,
109
+ )
110
+ scratch = _make_scratch(
111
+ [256, 512, 768, 768], features, groups=groups, expand=expand
112
+ ) # ViT-Hybrid (ResNet50 + ViT-B/16) backbone
113
+ elif backbone == "vitb16_384":
114
+ pretrained = _make_pretrained_vitb16_384(
115
+ use_pretrained, hooks=hooks, use_readout=use_readout
116
+ )
117
+ scratch = _make_scratch(
118
+ [96, 192, 384, 768], features, groups=groups, expand=expand
119
+ ) # ViT-B/16 - 84.6% Top1 (backbone)
120
+ elif backbone == "resnext101_wsl":
121
+ pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
122
+ scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand) # resnext101_wsl
123
+ elif backbone == "efficientnet_lite3":
124
+ pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
125
+ scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand) # efficientnet_lite3
126
+ else:
127
+ print(f"Backbone '{backbone}' not implemented")
128
+ assert False
129
+
130
+ return pretrained, scratch
131
+
132
+
133
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
134
+ scratch = nn.Module()
135
+
136
+ out_shape1 = out_shape
137
+ out_shape2 = out_shape
138
+ out_shape3 = out_shape
139
+ if len(in_shape) >= 4:
140
+ out_shape4 = out_shape
141
+
142
+ if expand:
143
+ out_shape1 = out_shape
144
+ out_shape2 = out_shape*2
145
+ out_shape3 = out_shape*4
146
+ if len(in_shape) >= 4:
147
+ out_shape4 = out_shape*8
148
+
149
+ scratch.layer1_rn = nn.Conv2d(
150
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
151
+ )
152
+ scratch.layer2_rn = nn.Conv2d(
153
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
154
+ )
155
+ scratch.layer3_rn = nn.Conv2d(
156
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
157
+ )
158
+ if len(in_shape) >= 4:
159
+ scratch.layer4_rn = nn.Conv2d(
160
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
161
+ )
162
+
163
+ return scratch
164
+
165
+
166
+ def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
167
+ efficientnet = torch.hub.load(
168
+ "rwightman/gen-efficientnet-pytorch",
169
+ "tf_efficientnet_lite3",
170
+ pretrained=use_pretrained,
171
+ exportable=exportable
172
+ )
173
+ return _make_efficientnet_backbone(efficientnet)
174
+
175
+
176
+ def _make_efficientnet_backbone(effnet):
177
+ pretrained = nn.Module()
178
+
179
+ pretrained.layer1 = nn.Sequential(
180
+ effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
181
+ )
182
+ pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
183
+ pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
184
+ pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])
185
+
186
+ return pretrained
187
+
188
+
189
+ def _make_resnet_backbone(resnet):
190
+ pretrained = nn.Module()
191
+ pretrained.layer1 = nn.Sequential(
192
+ resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
193
+ )
194
+
195
+ pretrained.layer2 = resnet.layer2
196
+ pretrained.layer3 = resnet.layer3
197
+ pretrained.layer4 = resnet.layer4
198
+
199
+ return pretrained
200
+
201
+
202
+ def _make_pretrained_resnext101_wsl(use_pretrained):
203
+ resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
204
+ return _make_resnet_backbone(resnet)
205
+
206
+
207
+
208
+ class Interpolate(nn.Module):
209
+ """Interpolation module.
210
+ """
211
+
212
+ def __init__(self, scale_factor, mode, align_corners=False):
213
+ """Init.
214
+
215
+ Args:
216
+ scale_factor (float): scaling
217
+ mode (str): interpolation mode
218
+ """
219
+ super(Interpolate, self).__init__()
220
+
221
+ self.interp = nn.functional.interpolate
222
+ self.scale_factor = scale_factor
223
+ self.mode = mode
224
+ self.align_corners = align_corners
225
+
226
+ def forward(self, x):
227
+ """Forward pass.
228
+
229
+ Args:
230
+ x (tensor): input
231
+
232
+ Returns:
233
+ tensor: interpolated data
234
+ """
235
+
236
+ x = self.interp(
237
+ x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
238
+ )
239
+
240
+ return x
241
+
242
+
243
+ class ResidualConvUnit(nn.Module):
244
+ """Residual convolution module.
245
+ """
246
+
247
+ def __init__(self, features):
248
+ """Init.
249
+
250
+ Args:
251
+ features (int): number of features
252
+ """
253
+ super().__init__()
254
+
255
+ self.conv1 = nn.Conv2d(
256
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
257
+ )
258
+
259
+ self.conv2 = nn.Conv2d(
260
+ features, features, kernel_size=3, stride=1, padding=1, bias=True
261
+ )
262
+
263
+ self.relu = nn.ReLU(inplace=True)
264
+
265
+ def forward(self, x):
266
+ """Forward pass.
267
+
268
+ Args:
269
+ x (tensor): input
270
+
271
+ Returns:
272
+ tensor: output
273
+ """
274
+ out = self.relu(x)
275
+ out = self.conv1(out)
276
+ out = self.relu(out)
277
+ out = self.conv2(out)
278
+
279
+ return out + x
280
+
281
+
282
+ class FeatureFusionBlock(nn.Module):
283
+ """Feature fusion block.
284
+ """
285
+
286
+ def __init__(self, features):
287
+ """Init.
288
+
289
+ Args:
290
+ features (int): number of features
291
+ """
292
+ super(FeatureFusionBlock, self).__init__()
293
+
294
+ self.resConfUnit1 = ResidualConvUnit(features)
295
+ self.resConfUnit2 = ResidualConvUnit(features)
296
+
297
+ def forward(self, *xs):
298
+ """Forward pass.
299
+
300
+ Returns:
301
+ tensor: output
302
+ """
303
+ output = xs[0]
304
+
305
+ if len(xs) == 2:
306
+ output += self.resConfUnit1(xs[1])
307
+
308
+ output = self.resConfUnit2(output)
309
+
310
+ output = nn.functional.interpolate(
311
+ output, scale_factor=2, mode="bilinear", align_corners=True
312
+ )
313
+
314
+ return output
315
+
316
+
317
+
318
+
319
+ class ResidualConvUnit_custom(nn.Module):
320
+ """Residual convolution module.
321
+ """
322
+
323
+ def __init__(self, features, activation, bn):
324
+ """Init.
325
+
326
+ Args:
327
+ features (int): number of features
328
+ """
329
+ super().__init__()
330
+
331
+ self.bn = bn
332
+
333
+ self.groups=1
334
+
335
+ self.conv1 = nn.Conv2d(
336
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
337
+ )
338
+
339
+ self.conv2 = nn.Conv2d(
340
+ features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
341
+ )
342
+
343
+ if self.bn==True:
344
+ self.bn1 = nn.BatchNorm2d(features)
345
+ self.bn2 = nn.BatchNorm2d(features)
346
+
347
+ self.activation = activation
348
+
349
+ self.skip_add = nn.quantized.FloatFunctional()
350
+
351
+ def forward(self, x):
352
+ """Forward pass.
353
+
354
+ Args:
355
+ x (tensor): input
356
+
357
+ Returns:
358
+ tensor: output
359
+ """
360
+
361
+ out = self.activation(x)
362
+ out = self.conv1(out)
363
+ if self.bn==True:
364
+ out = self.bn1(out)
365
+
366
+ out = self.activation(out)
367
+ out = self.conv2(out)
368
+ if self.bn==True:
369
+ out = self.bn2(out)
370
+
371
+ if self.groups > 1:
372
+ out = self.conv_merge(out)
373
+
374
+ return self.skip_add.add(out, x)
375
+
376
+ # return out + x
377
+
378
+
379
+ class FeatureFusionBlock_custom(nn.Module):
380
+ """Feature fusion block.
381
+ """
382
+
383
+ def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
384
+ """Init.
385
+
386
+ Args:
387
+ features (int): number of features
388
+ """
389
+ super(FeatureFusionBlock_custom, self).__init__()
390
+
391
+ self.deconv = deconv
392
+ self.align_corners = align_corners
393
+
394
+ self.groups=1
395
+
396
+ self.expand = expand
397
+ out_features = features
398
+ if self.expand==True:
399
+ out_features = features//2
400
+
401
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
402
+
403
+ self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
404
+ self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)
405
+
406
+ self.skip_add = nn.quantized.FloatFunctional()
407
+
408
+ self.size=size
409
+
410
+ def forward(self, *xs, size=None):
411
+ """Forward pass.
412
+
413
+ Returns:
414
+ tensor: output
415
+ """
416
+ output = xs[0]
417
+
418
+ if len(xs) == 2:
419
+ res = self.resConfUnit1(xs[1])
420
+ output = self.skip_add.add(output, res)
421
+ # output += res
422
+
423
+ output = self.resConfUnit2(output)
424
+
425
+ if (size is None) and (self.size is None):
426
+ modifier = {"scale_factor": 2}
427
+ elif size is None:
428
+ modifier = {"size": self.size}
429
+ else:
430
+ modifier = {"size": size}
431
+
432
+ output = nn.functional.interpolate(
433
+ output, **modifier, mode="bilinear", align_corners=self.align_corners
434
+ )
435
+
436
+ output = self.out_conv(output)
437
+
438
+ return output
439
+
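
A short sketch of the custom fusion block defined above: two feature maps at the same resolution are fused and upsampled by a factor of 2, which is how the DPT decoder chains its refinenet stages. Importing blocks.py also pulls in the backbone modules, so this assumes the bundled custom_timm package is importable.

import torch
import torch.nn as nn
from custom_midas_repo.midas.blocks import FeatureFusionBlock_custom

block = FeatureFusionBlock_custom(features=64, activation=nn.ReLU(False), bn=False, align_corners=True)
skip = torch.randn(1, 64, 24, 24)     # encoder feature map
deeper = torch.randn(1, 64, 24, 24)   # output of the previous (deeper) fusion block
with torch.no_grad():
    out = block(deeper, skip)
print(out.shape)  # torch.Size([1, 64, 48, 48]) - fused and upsampled by 2
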
custom_midas_repo/midas/dpt_depth.py ADDED
@@ -0,0 +1,166 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .base_model import BaseModel
5
+ from .blocks import (
6
+ FeatureFusionBlock_custom,
7
+ Interpolate,
8
+ _make_encoder,
9
+ forward_beit,
10
+ forward_swin,
11
+ forward_levit,
12
+ forward_vit,
13
+ )
14
+ from .backbones.levit import stem_b4_transpose
15
+ from custom_timm.models.layers import get_act_layer
16
+
17
+
18
+ def _make_fusion_block(features, use_bn, size = None):
19
+ return FeatureFusionBlock_custom(
20
+ features,
21
+ nn.ReLU(False),
22
+ deconv=False,
23
+ bn=use_bn,
24
+ expand=False,
25
+ align_corners=True,
26
+ size=size,
27
+ )
28
+
29
+
30
+ class DPT(BaseModel):
31
+ def __init__(
32
+ self,
33
+ head,
34
+ features=256,
35
+ backbone="vitb_rn50_384",
36
+ readout="project",
37
+ channels_last=False,
38
+ use_bn=False,
39
+ **kwargs
40
+ ):
41
+
42
+ super(DPT, self).__init__()
43
+
44
+ self.channels_last = channels_last
45
+
46
+ # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
47
+ # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
48
+ hooks = {
49
+ "beitl16_512": [5, 11, 17, 23],
50
+ "beitl16_384": [5, 11, 17, 23],
51
+ "beitb16_384": [2, 5, 8, 11],
52
+ "swin2l24_384": [1, 1, 17, 1], # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
53
+ "swin2b24_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
54
+ "swin2t16_256": [1, 1, 5, 1], # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
55
+ "swinl12_384": [1, 1, 17, 1], # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
56
+ "next_vit_large_6m": [2, 6, 36, 39], # [0, 2], [3, 6], [ 7, 36], [37, 39]
57
+ "levit_384": [3, 11, 21], # [0, 3], [6, 11], [14, 21]
58
+ "vitb_rn50_384": [0, 1, 8, 11],
59
+ "vitb16_384": [2, 5, 8, 11],
60
+ "vitl16_384": [5, 11, 17, 23],
61
+ }[backbone]
62
+
63
+ if "next_vit" in backbone:
64
+ in_features = {
65
+ "next_vit_large_6m": [96, 256, 512, 1024],
66
+ }[backbone]
67
+ else:
68
+ in_features = None
69
+
70
+ # Instantiate backbone and reassemble blocks
71
+ self.pretrained, self.scratch = _make_encoder(
72
+ backbone,
73
+ features,
74
+ False, # Set to true of you want to train from scratch, uses ImageNet weights
75
+ groups=1,
76
+ expand=False,
77
+ exportable=False,
78
+ hooks=hooks,
79
+ use_readout=readout,
80
+ in_features=in_features,
81
+ )
82
+
83
+ self.number_layers = len(hooks) if hooks is not None else 4
84
+ size_refinenet3 = None
85
+ self.scratch.stem_transpose = None
86
+
87
+ if "beit" in backbone:
88
+ self.forward_transformer = forward_beit
89
+ elif "swin" in backbone:
90
+ self.forward_transformer = forward_swin
91
+ elif "next_vit" in backbone:
92
+ from .backbones.next_vit import forward_next_vit
93
+ self.forward_transformer = forward_next_vit
94
+ elif "levit" in backbone:
95
+ self.forward_transformer = forward_levit
96
+ size_refinenet3 = 7
97
+ self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
98
+ else:
99
+ self.forward_transformer = forward_vit
100
+
101
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
104
+ if self.number_layers >= 4:
105
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
106
+
107
+ self.scratch.output_conv = head
108
+
109
+
110
+ def forward(self, x):
111
+ if self.channels_last:
112
+ x = x.contiguous(memory_format=torch.channels_last)
113
+
114
+ layers = self.forward_transformer(self.pretrained, x)
115
+ if self.number_layers == 3:
116
+ layer_1, layer_2, layer_3 = layers
117
+ else:
118
+ layer_1, layer_2, layer_3, layer_4 = layers
119
+
120
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
121
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
122
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
123
+ if self.number_layers >= 4:
124
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
125
+
126
+ if self.number_layers == 3:
127
+ path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
128
+ else:
129
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
130
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
131
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
132
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
133
+
134
+ if self.scratch.stem_transpose is not None:
135
+ path_1 = self.scratch.stem_transpose(path_1)
136
+
137
+ out = self.scratch.output_conv(path_1)
138
+
139
+ return out
140
+
141
+
142
+ class DPTDepthModel(DPT):
143
+ def __init__(self, path=None, non_negative=True, **kwargs):
144
+ features = kwargs["features"] if "features" in kwargs else 256
145
+ head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
146
+ head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
147
+ kwargs.pop("head_features_1", None)
148
+ kwargs.pop("head_features_2", None)
149
+
150
+ head = nn.Sequential(
151
+ nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
152
+ Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
153
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
154
+ nn.ReLU(True),
155
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
156
+ nn.ReLU(True) if non_negative else nn.Identity(),
157
+ nn.Identity(),
158
+ )
159
+
160
+ super().__init__(head, **kwargs)
161
+
162
+ if path is not None:
163
+ self.load(path)
164
+
165
+ def forward(self, x):
166
+ return super().forward(x).squeeze(dim=1)
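
A hedged construction sketch for DPTDepthModel: with path=None no checkpoint is needed and the backbone is built with random weights (use_pretrained is hard-coded to False in _make_encoder), so this runs offline, assuming the bundled custom_timm package can build "vit_large_patch16_384".

import torch
from custom_midas_repo.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitl16_384", non_negative=True)
model.eval()
with torch.no_grad():
    depth = model(torch.randn(1, 3, 384, 384))
print(depth.shape)  # torch.Size([1, 384, 384]) - one inverse-depth map per image
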
custom_midas_repo/midas/midas_net.py ADDED
@@ -0,0 +1,76 @@
1
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=256, non_negative=True):
17
+ """Init.
18
+
19
+ Args:
20
+ path (str, optional): Path to saved model. Defaults to None.
21
+ features (int, optional): Number of features. Defaults to 256.
22
+ non_negative (bool, optional): If True, clamp the predicted depth to non-negative values. Defaults to True.
23
+ """
24
+ print("Loading weights: ", path)
25
+
26
+ super(MidasNet, self).__init__()
27
+
28
+ use_pretrained = False if path is None else True
29
+
30
+ self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)
31
+
32
+ self.scratch.refinenet4 = FeatureFusionBlock(features)
33
+ self.scratch.refinenet3 = FeatureFusionBlock(features)
34
+ self.scratch.refinenet2 = FeatureFusionBlock(features)
35
+ self.scratch.refinenet1 = FeatureFusionBlock(features)
36
+
37
+ self.scratch.output_conv = nn.Sequential(
38
+ nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
39
+ Interpolate(scale_factor=2, mode="bilinear"),
40
+ nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
41
+ nn.ReLU(True),
42
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
43
+ nn.ReLU(True) if non_negative else nn.Identity(),
44
+ )
45
+
46
+ if path:
47
+ self.load(path)
48
+
49
+ def forward(self, x):
50
+ """Forward pass.
51
+
52
+ Args:
53
+ x (tensor): input data (image)
54
+
55
+ Returns:
56
+ tensor: depth
57
+ """
58
+
59
+ layer_1 = self.pretrained.layer1(x)
60
+ layer_2 = self.pretrained.layer2(layer_1)
61
+ layer_3 = self.pretrained.layer3(layer_2)
62
+ layer_4 = self.pretrained.layer4(layer_3)
63
+
64
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
65
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
66
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
67
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
68
+
69
+ path_4 = self.scratch.refinenet4(layer_4_rn)
70
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
71
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
72
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
73
+
74
+ out = self.scratch.output_conv(path_1)
75
+
76
+ return torch.squeeze(out, dim=1)
custom_midas_repo/midas/midas_net_custom.py ADDED
@@ -0,0 +1,128 @@
1
+ """MidasNet: Network for monocular depth estimation trained by mixing several datasets.
2
+ This file contains code that is adapted from
3
+ https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
4
+ """
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from .base_model import BaseModel
9
+ from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder
10
+
11
+
12
+ class MidasNet_small(BaseModel):
13
+ """Network for monocular depth estimation.
14
+ """
15
+
16
+ def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
17
+ blocks={'expand': True}):
18
+ """Init.
19
+
20
+ Args:
21
+ path (str, optional): Path to saved model. Defaults to None.
22
+ features (int, optional): Number of features. Defaults to 64.
23
+ backbone (str, optional): Backbone network for the encoder. Defaults to "efficientnet_lite3".
24
+ """
25
+ print("Loading weights: ", path)
26
+
27
+ super(MidasNet_small, self).__init__()
28
+
29
+ use_pretrained = False if path else True
30
+
31
+ self.channels_last = channels_last
32
+ self.blocks = blocks
33
+ self.backbone = backbone
34
+
35
+ self.groups = 1
36
+
37
+ features1=features
38
+ features2=features
39
+ features3=features
40
+ features4=features
41
+ self.expand = False
42
+ if "expand" in self.blocks and self.blocks['expand'] == True:
43
+ self.expand = True
44
+ features1=features
45
+ features2=features*2
46
+ features3=features*4
47
+ features4=features*8
48
+
49
+ self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)
50
+
51
+ self.scratch.activation = nn.ReLU(False)
52
+
53
+ self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
54
+ self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
55
+ self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
56
+ self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)
57
+
58
+
59
+ self.scratch.output_conv = nn.Sequential(
60
+ nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
61
+ Interpolate(scale_factor=2, mode="bilinear"),
62
+ nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
63
+ self.scratch.activation,
64
+ nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
65
+ nn.ReLU(True) if non_negative else nn.Identity(),
66
+ nn.Identity(),
67
+ )
68
+
69
+ if path:
70
+ self.load(path)
71
+
72
+
73
+ def forward(self, x):
74
+ """Forward pass.
75
+
76
+ Args:
77
+ x (tensor): input data (image)
78
+
79
+ Returns:
80
+ tensor: depth
81
+ """
82
+ if self.channels_last:
83
+ print("self.channels_last = ", self.channels_last)
84
+ x = x.contiguous(memory_format=torch.channels_last)
85
+
86
+
87
+ layer_1 = self.pretrained.layer1(x)
88
+ layer_2 = self.pretrained.layer2(layer_1)
89
+ layer_3 = self.pretrained.layer3(layer_2)
90
+ layer_4 = self.pretrained.layer4(layer_3)
91
+
92
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
93
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
94
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
95
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
96
+
97
+
98
+ path_4 = self.scratch.refinenet4(layer_4_rn)
99
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
100
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
101
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
102
+
103
+ out = self.scratch.output_conv(path_1)
104
+
105
+ return torch.squeeze(out, dim=1)
106
+
107
+
108
+
109
+ def fuse_model(m):
110
+ prev_previous_type = nn.Identity()
111
+ prev_previous_name = ''
112
+ previous_type = nn.Identity()
113
+ previous_name = ''
114
+ for name, module in m.named_modules():
115
+ if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
116
+ # print("FUSED ", prev_previous_name, previous_name, name)
117
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
118
+ elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
119
+ # print("FUSED ", prev_previous_name, previous_name)
120
+ torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
121
+ # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
122
+ # print("FUSED ", previous_name, name)
123
+ # torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)
124
+
125
+ prev_previous_type = previous_type
126
+ prev_previous_name = previous_name
127
+ previous_type = type(module)
128
+ previous_name = name
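
fuse_model walks named_modules and fuses Conv-BN(-ReLU) triples for quantization. Its intended target is MidasNet_small, which needs downloaded backbone weights, so the sketch below applies it to a toy Conv-BN-ReLU stack instead to show the effect:

import torch
import torch.nn as nn
from custom_midas_repo.midas.midas_net_custom import fuse_model

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
toy.eval()       # fusing expects eval mode
fuse_model(toy)
print(toy)       # index 0 is now a fused Conv+BN+ReLU module; indices 1 and 2 become Identity
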
custom_midas_repo/midas/model_loader.py ADDED
@@ -0,0 +1,242 @@
1
+ import cv2
2
+ import torch
3
+
4
+ from custom_midas_repo.midas.dpt_depth import DPTDepthModel
5
+ from custom_midas_repo.midas.midas_net import MidasNet
6
+ from custom_midas_repo.midas.midas_net_custom import MidasNet_small
7
+ from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet
8
+
9
+ from torchvision.transforms import Compose
10
+
11
+ default_models = {
12
+ "dpt_beit_large_512": "weights/dpt_beit_large_512.pt",
13
+ "dpt_beit_large_384": "weights/dpt_beit_large_384.pt",
14
+ "dpt_beit_base_384": "weights/dpt_beit_base_384.pt",
15
+ "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt",
16
+ "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt",
17
+ "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt",
18
+ "dpt_swin_large_384": "weights/dpt_swin_large_384.pt",
19
+ "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt",
20
+ "dpt_levit_224": "weights/dpt_levit_224.pt",
21
+ "dpt_large_384": "weights/dpt_large_384.pt",
22
+ "dpt_hybrid_384": "weights/dpt_hybrid_384.pt",
23
+ "midas_v21_384": "weights/midas_v21_384.pt",
24
+ "midas_v21_small_256": "weights/midas_v21_small_256.pt",
25
+ "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml",
26
+ }
27
+
28
+
29
+ def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
30
+ """Load the specified network.
31
+
32
+ Args:
33
+ device (device): the torch device used
34
+ model_path (str): path to saved model
35
+ model_type (str): the type of the model to be loaded
36
+ optimize (bool): optimize the model to half-integer on CUDA?
37
+ height (int): inference encoder image height
38
+ square (bool): resize to a square resolution?
39
+
40
+ Returns:
41
+ The loaded network, the transform which prepares images as input to the network and the dimensions of the
42
+ network input
43
+ """
44
+ if "openvino" in model_type:
45
+ from openvino.runtime import Core
46
+
47
+ keep_aspect_ratio = not square
48
+
49
+ if model_type == "dpt_beit_large_512":
50
+ model = DPTDepthModel(
51
+ path=model_path,
52
+ backbone="beitl16_512",
53
+ non_negative=True,
54
+ )
55
+ net_w, net_h = 512, 512
56
+ resize_mode = "minimal"
57
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
58
+
59
+ elif model_type == "dpt_beit_large_384":
60
+ model = DPTDepthModel(
61
+ path=model_path,
62
+ backbone="beitl16_384",
63
+ non_negative=True,
64
+ )
65
+ net_w, net_h = 384, 384
66
+ resize_mode = "minimal"
67
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
68
+
69
+ elif model_type == "dpt_beit_base_384":
70
+ model = DPTDepthModel(
71
+ path=model_path,
72
+ backbone="beitb16_384",
73
+ non_negative=True,
74
+ )
75
+ net_w, net_h = 384, 384
76
+ resize_mode = "minimal"
77
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
78
+
79
+ elif model_type == "dpt_swin2_large_384":
80
+ model = DPTDepthModel(
81
+ path=model_path,
82
+ backbone="swin2l24_384",
83
+ non_negative=True,
84
+ )
85
+ net_w, net_h = 384, 384
86
+ keep_aspect_ratio = False
87
+ resize_mode = "minimal"
88
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
89
+
90
+ elif model_type == "dpt_swin2_base_384":
91
+ model = DPTDepthModel(
92
+ path=model_path,
93
+ backbone="swin2b24_384",
94
+ non_negative=True,
95
+ )
96
+ net_w, net_h = 384, 384
97
+ keep_aspect_ratio = False
98
+ resize_mode = "minimal"
99
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
100
+
101
+ elif model_type == "dpt_swin2_tiny_256":
102
+ model = DPTDepthModel(
103
+ path=model_path,
104
+ backbone="swin2t16_256",
105
+ non_negative=True,
106
+ )
107
+ net_w, net_h = 256, 256
108
+ keep_aspect_ratio = False
109
+ resize_mode = "minimal"
110
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
111
+
112
+ elif model_type == "dpt_swin_large_384":
113
+ model = DPTDepthModel(
114
+ path=model_path,
115
+ backbone="swinl12_384",
116
+ non_negative=True,
117
+ )
118
+ net_w, net_h = 384, 384
119
+ keep_aspect_ratio = False
120
+ resize_mode = "minimal"
121
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
122
+
123
+ elif model_type == "dpt_next_vit_large_384":
124
+ model = DPTDepthModel(
125
+ path=model_path,
126
+ backbone="next_vit_large_6m",
127
+ non_negative=True,
128
+ )
129
+ net_w, net_h = 384, 384
130
+ resize_mode = "minimal"
131
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
132
+
133
+ # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
134
+ # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
135
+ # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
136
+ # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
137
+ elif model_type == "dpt_levit_224":
138
+ model = DPTDepthModel(
139
+ path=model_path,
140
+ backbone="levit_384",
141
+ non_negative=True,
142
+ head_features_1=64,
143
+ head_features_2=8,
144
+ )
145
+ net_w, net_h = 224, 224
146
+ keep_aspect_ratio = False
147
+ resize_mode = "minimal"
148
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
149
+
150
+ elif model_type == "dpt_large_384":
151
+ model = DPTDepthModel(
152
+ path=model_path,
153
+ backbone="vitl16_384",
154
+ non_negative=True,
155
+ )
156
+ net_w, net_h = 384, 384
157
+ resize_mode = "minimal"
158
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
159
+
160
+ elif model_type == "dpt_hybrid_384":
161
+ model = DPTDepthModel(
162
+ path=model_path,
163
+ backbone="vitb_rn50_384",
164
+ non_negative=True,
165
+ )
166
+ net_w, net_h = 384, 384
167
+ resize_mode = "minimal"
168
+ normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
169
+
170
+ elif model_type == "midas_v21_384":
171
+ model = MidasNet(model_path, non_negative=True)
172
+ net_w, net_h = 384, 384
173
+ resize_mode = "upper_bound"
174
+ normalization = NormalizeImage(
175
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
176
+ )
177
+
178
+ elif model_type == "midas_v21_small_256":
179
+ model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
180
+ non_negative=True, blocks={'expand': True})
181
+ net_w, net_h = 256, 256
182
+ resize_mode = "upper_bound"
183
+ normalization = NormalizeImage(
184
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
185
+ )
186
+
187
+ elif model_type == "openvino_midas_v21_small_256":
188
+ ie = Core()
189
+ uncompiled_model = ie.read_model(model=model_path)
190
+ model = ie.compile_model(uncompiled_model, "CPU")
191
+ net_w, net_h = 256, 256
192
+ resize_mode = "upper_bound"
193
+ normalization = NormalizeImage(
194
+ mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
195
+ )
196
+
197
+ else:
198
+ print(f"model_type '{model_type}' not implemented, use: --model_type large")
199
+ assert False
200
+
201
+ if "openvino" not in model_type:
202
+ print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
203
+ else:
204
+ print("Model loaded, optimized with OpenVINO")
205
+
206
+ if "openvino" in model_type:
207
+ keep_aspect_ratio = False
208
+
209
+ if height is not None:
210
+ net_w, net_h = height, height
211
+
212
+ transform = Compose(
213
+ [
214
+ Resize(
215
+ net_w,
216
+ net_h,
217
+ resize_target=None,
218
+ keep_aspect_ratio=keep_aspect_ratio,
219
+ ensure_multiple_of=32,
220
+ resize_method=resize_mode,
221
+ image_interpolation_method=cv2.INTER_CUBIC,
222
+ ),
223
+ normalization,
224
+ PrepareForNet(),
225
+ ]
226
+ )
227
+
228
+ if "openvino" not in model_type:
229
+ model.eval()
230
+
231
+ if optimize and (device == torch.device("cuda")):
232
+ if "openvino" not in model_type:
233
+ model = model.to(memory_format=torch.channels_last)
234
+ model = model.half()
235
+ else:
236
+ print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
237
+ exit()
238
+
239
+ if "openvino" not in model_type:
240
+ model.to(device)
241
+
242
+ return model, transform, net_w, net_h
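
A hedged usage sketch for load_model: it assumes the matching checkpoint has already been downloaded to the path listed in default_models, and "input.jpg" stands in for any RGB image on disk.

import cv2
import torch
from custom_midas_repo.midas.model_loader import load_model, default_models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_type = "dpt_large_384"
model, transform, net_w, net_h = load_model(device, default_models[model_type], model_type, optimize=False)

img = cv2.cvtColor(cv2.imread("input.jpg"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})["image"]    # HWC image in [0, 1] -> CHW float32 array
with torch.no_grad():
    prediction = model(torch.from_numpy(sample).unsqueeze(0).to(device))
print(prediction.shape)  # (1, H', W') inverse depth at the network's working resolution
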
custom_midas_repo/midas/transforms.py ADDED
@@ -0,0 +1,234 @@
1
+ import numpy as np
2
+ import cv2
3
+ import math
4
+
5
+
6
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
7
+ """Rezise the sample to ensure the given size. Keeps aspect ratio.
8
+
9
+ Args:
10
+ sample (dict): sample
11
+ size (tuple): image size
12
+
13
+ Returns:
14
+ tuple: new size
15
+ """
16
+ shape = list(sample["disparity"].shape)
17
+
18
+ if shape[0] >= size[0] and shape[1] >= size[1]:
19
+ return sample
20
+
21
+ scale = [0, 0]
22
+ scale[0] = size[0] / shape[0]
23
+ scale[1] = size[1] / shape[1]
24
+
25
+ scale = max(scale)
26
+
27
+ shape[0] = math.ceil(scale * shape[0])
28
+ shape[1] = math.ceil(scale * shape[1])
29
+
30
+ # resize
31
+ sample["image"] = cv2.resize(
32
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
33
+ )
34
+
35
+ sample["disparity"] = cv2.resize(
36
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
37
+ )
38
+ sample["mask"] = cv2.resize(
39
+ sample["mask"].astype(np.float32),
40
+ tuple(shape[::-1]),
41
+ interpolation=cv2.INTER_NEAREST,
42
+ )
43
+ sample["mask"] = sample["mask"].astype(bool)
44
+
45
+ return tuple(shape)
46
+
47
+
48
+ class Resize(object):
49
+ """Resize sample to given size (width, height).
50
+ """
51
+
52
+ def __init__(
53
+ self,
54
+ width,
55
+ height,
56
+ resize_target=True,
57
+ keep_aspect_ratio=False,
58
+ ensure_multiple_of=1,
59
+ resize_method="lower_bound",
60
+ image_interpolation_method=cv2.INTER_AREA,
61
+ ):
62
+ """Init.
63
+
64
+ Args:
65
+ width (int): desired output width
66
+ height (int): desired output height
67
+ resize_target (bool, optional):
68
+ True: Resize the full sample (image, mask, target).
69
+ False: Resize image only.
70
+ Defaults to True.
71
+ keep_aspect_ratio (bool, optional):
72
+ True: Keep the aspect ratio of the input sample.
73
+ Output sample might not have the given width and height, and
74
+ resize behaviour depends on the parameter 'resize_method'.
75
+ Defaults to False.
76
+ ensure_multiple_of (int, optional):
77
+ Output width and height is constrained to be multiple of this parameter.
78
+ Defaults to 1.
79
+ resize_method (str, optional):
80
+ "lower_bound": Output will be at least as large as the given size.
81
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
82
+ "upper_bound": Output will be at most as large as the given size. (Output size might be smaller than given size.)
83
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
84
+ """
85
+ self.__width = width
86
+ self.__height = height
87
+
88
+ self.__resize_target = resize_target
89
+ self.__keep_aspect_ratio = keep_aspect_ratio
90
+ self.__multiple_of = ensure_multiple_of
91
+ self.__resize_method = resize_method
92
+ self.__image_interpolation_method = image_interpolation_method
93
+
94
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
95
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
96
+
97
+ if max_val is not None and y > max_val:
98
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
99
+
100
+ if y < min_val:
101
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
102
+
103
+ return y
104
+
105
+ def get_size(self, width, height):
106
+ # determine new height and width
107
+ scale_height = self.__height / height
108
+ scale_width = self.__width / width
109
+
110
+ if self.__keep_aspect_ratio:
111
+ if self.__resize_method == "lower_bound":
112
+ # scale such that output size is lower bound
113
+ if scale_width > scale_height:
114
+ # fit width
115
+ scale_height = scale_width
116
+ else:
117
+ # fit height
118
+ scale_width = scale_height
119
+ elif self.__resize_method == "upper_bound":
120
+ # scale such that output size is upper bound
121
+ if scale_width < scale_height:
122
+ # fit width
123
+ scale_height = scale_width
124
+ else:
125
+ # fit height
126
+ scale_width = scale_height
127
+ elif self.__resize_method == "minimal":
128
+ # scale as little as possible
129
+ if abs(1 - scale_width) < abs(1 - scale_height):
130
+ # fit width
131
+ scale_height = scale_width
132
+ else:
133
+ # fit height
134
+ scale_width = scale_height
135
+ else:
136
+ raise ValueError(
137
+ f"resize_method {self.__resize_method} not implemented"
138
+ )
139
+
140
+ if self.__resize_method == "lower_bound":
141
+ new_height = self.constrain_to_multiple_of(
142
+ scale_height * height, min_val=self.__height
143
+ )
144
+ new_width = self.constrain_to_multiple_of(
145
+ scale_width * width, min_val=self.__width
146
+ )
147
+ elif self.__resize_method == "upper_bound":
148
+ new_height = self.constrain_to_multiple_of(
149
+ scale_height * height, max_val=self.__height
150
+ )
151
+ new_width = self.constrain_to_multiple_of(
152
+ scale_width * width, max_val=self.__width
153
+ )
154
+ elif self.__resize_method == "minimal":
155
+ new_height = self.constrain_to_multiple_of(scale_height * height)
156
+ new_width = self.constrain_to_multiple_of(scale_width * width)
157
+ else:
158
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
159
+
160
+ return (new_width, new_height)
161
+
162
+ def __call__(self, sample):
163
+ width, height = self.get_size(
164
+ sample["image"].shape[1], sample["image"].shape[0]
165
+ )
166
+
167
+ # resize sample
168
+ sample["image"] = cv2.resize(
169
+ sample["image"],
170
+ (width, height),
171
+ interpolation=self.__image_interpolation_method,
172
+ )
173
+
174
+ if self.__resize_target:
175
+ if "disparity" in sample:
176
+ sample["disparity"] = cv2.resize(
177
+ sample["disparity"],
178
+ (width, height),
179
+ interpolation=cv2.INTER_NEAREST,
180
+ )
181
+
182
+ if "depth" in sample:
183
+ sample["depth"] = cv2.resize(
184
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
185
+ )
186
+
187
+ sample["mask"] = cv2.resize(
188
+ sample["mask"].astype(np.float32),
189
+ (width, height),
190
+ interpolation=cv2.INTER_NEAREST,
191
+ )
192
+ sample["mask"] = sample["mask"].astype(bool)
193
+
194
+ return sample
195
+
196
+
197
+ class NormalizeImage(object):
198
+ """Normlize image by given mean and std.
199
+ """
200
+
201
+ def __init__(self, mean, std):
202
+ self.__mean = mean
203
+ self.__std = std
204
+
205
+ def __call__(self, sample):
206
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
207
+
208
+ return sample
209
+
210
+
211
+ class PrepareForNet(object):
212
+ """Prepare sample for usage as network input.
213
+ """
214
+
215
+ def __init__(self):
216
+ pass
217
+
218
+ def __call__(self, sample):
219
+ image = np.transpose(sample["image"], (2, 0, 1))
220
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
221
+
222
+ if "mask" in sample:
223
+ sample["mask"] = sample["mask"].astype(np.float32)
224
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
225
+
226
+ if "disparity" in sample:
227
+ disparity = sample["disparity"].astype(np.float32)
228
+ sample["disparity"] = np.ascontiguousarray(disparity)
229
+
230
+ if "depth" in sample:
231
+ depth = sample["depth"].astype(np.float32)
232
+ sample["depth"] = np.ascontiguousarray(depth)
233
+
234
+ return sample
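
A self-contained sketch of the transform pipeline above, mirroring the "upper_bound" configuration that model_loader.py uses for midas_v21_384; the random array stands in for an RGB image scaled to [0, 1].

import numpy as np
import cv2
from torchvision.transforms import Compose
from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="upper_bound",
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

image = np.random.rand(480, 640, 3).astype(np.float32)   # stand-in for a 640x480 RGB image
out = transform({"image": image})["image"]
print(out.shape, out.dtype)  # (3, 288, 384) float32 - longest side capped at 384, both sides multiples of 32
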
custom_mmpkg/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #Dummy file ensuring this package will be recognized
custom_mmpkg/custom_mmcv/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ # flake8: noqa
3
+ from .arraymisc import *
4
+ from .fileio import *
5
+ from .image import *
6
+ from .utils import *
7
+ from .version import *
8
+ from .video import *
9
+ from .visualization import *
10
+
11
+ # The following modules are not imported to this level, so mmcv may be used
12
+ # without PyTorch.
13
+ # - runner
14
+ # - parallel
15
+ # - op
custom_mmpkg/custom_mmcv/arraymisc/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .quantization import dequantize, quantize
3
+
4
+ __all__ = ['quantize', 'dequantize']
custom_mmpkg/custom_mmcv/arraymisc/quantization.py ADDED
@@ -0,0 +1,55 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import numpy as np
3
+
4
+
5
+ def quantize(arr, min_val, max_val, levels, dtype=np.int64):
6
+ """Quantize an array of (-inf, inf) to [0, levels-1].
7
+
8
+ Args:
9
+ arr (ndarray): Input array.
10
+ min_val (scalar): Minimum value to be clipped.
11
+ max_val (scalar): Maximum value to be clipped.
12
+ levels (int): Quantization levels.
13
+ dtype (np.type): The type of the quantized array.
14
+
15
+ Returns:
16
+ ndarray: Quantized array.
17
+ """
18
+ if not (isinstance(levels, int) and levels > 1):
19
+ raise ValueError(
20
+ f'levels must be a positive integer, but got {levels}')
21
+ if min_val >= max_val:
22
+ raise ValueError(
23
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
24
+
25
+ arr = np.clip(arr, min_val, max_val) - min_val
26
+ quantized_arr = np.minimum(
27
+ np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
28
+
29
+ return quantized_arr
30
+
31
+
32
+ def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
33
+ """Dequantize an array.
34
+
35
+ Args:
36
+ arr (ndarray): Input array.
37
+ min_val (scalar): Minimum value to be clipped.
38
+ max_val (scalar): Maximum value to be clipped.
39
+ levels (int): Quantization levels.
40
+ dtype (np.type): The type of the dequantized array.
41
+
42
+ Returns:
43
+ tuple: Dequantized array.
44
+ """
45
+ if not (isinstance(levels, int) and levels > 1):
46
+ raise ValueError(
47
+ f'levels must be a positive integer, but got {levels}')
48
+ if min_val >= max_val:
49
+ raise ValueError(
50
+ f'min_val ({min_val}) must be smaller than max_val ({max_val})')
51
+
52
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
53
+ min_val) / levels + min_val
54
+
55
+ return dequantized_arr
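A quick sketch of the quantize/dequantize round trip (illustrative only; the import path follows the arraymisc __init__ above):

    import numpy as np
    from custom_mmpkg.custom_mmcv.arraymisc import quantize, dequantize

    arr = np.array([-0.3, 0.0, 0.4, 1.2])
    q = quantize(arr, min_val=0.0, max_val=1.0, levels=10)    # clipped to [0, 1], then binned -> [0, 0, 4, 9]
    deq = dequantize(q, min_val=0.0, max_val=1.0, levels=10)  # bin centres -> [0.05, 0.05, 0.45, 0.95]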
custom_mmpkg/custom_mmcv/cnn/__init__.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .alexnet import AlexNet
3
+ # yapf: disable
4
+ from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
5
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
6
+ ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
7
+ ConvTranspose2d, ConvTranspose3d, ConvWS2d,
8
+ DepthwiseSeparableConvModule, GeneralizedAttention,
9
+ HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
10
+ NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
11
+ build_activation_layer, build_conv_layer,
12
+ build_norm_layer, build_padding_layer, build_plugin_layer,
13
+ build_upsample_layer, conv_ws_2d, is_norm)
14
+ from .builder import MODELS, build_model_from_cfg
15
+ # yapf: enable
16
+ from .resnet import ResNet, make_res_layer
17
+ from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
18
+ NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
19
+ XavierInit, bias_init_with_prob, caffe2_xavier_init,
20
+ constant_init, fuse_conv_bn, get_model_complexity_info,
21
+ initialize, kaiming_init, normal_init, trunc_normal_init,
22
+ uniform_init, xavier_init)
23
+ from .vgg import VGG, make_vgg_layer
24
+
25
+ __all__ = [
26
+ 'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
27
+ 'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
28
+ 'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
29
+ 'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
30
+ 'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
31
+ 'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
32
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
33
+ 'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
34
+ 'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
35
+ 'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
36
+ 'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
37
+ 'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
38
+ 'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
39
+ 'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
40
+ 'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
41
+ ]
custom_mmpkg/custom_mmcv/cnn/alexnet.py ADDED
@@ -0,0 +1,61 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import logging
3
+
4
+ import torch.nn as nn
5
+
6
+
7
+ class AlexNet(nn.Module):
8
+ """AlexNet backbone.
9
+
10
+ Args:
11
+ num_classes (int): number of classes for classification.
12
+ """
13
+
14
+ def __init__(self, num_classes=-1):
15
+ super(AlexNet, self).__init__()
16
+ self.num_classes = num_classes
17
+ self.features = nn.Sequential(
18
+ nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
19
+ nn.ReLU(inplace=True),
20
+ nn.MaxPool2d(kernel_size=3, stride=2),
21
+ nn.Conv2d(64, 192, kernel_size=5, padding=2),
22
+ nn.ReLU(inplace=True),
23
+ nn.MaxPool2d(kernel_size=3, stride=2),
24
+ nn.Conv2d(192, 384, kernel_size=3, padding=1),
25
+ nn.ReLU(inplace=True),
26
+ nn.Conv2d(384, 256, kernel_size=3, padding=1),
27
+ nn.ReLU(inplace=True),
28
+ nn.Conv2d(256, 256, kernel_size=3, padding=1),
29
+ nn.ReLU(inplace=True),
30
+ nn.MaxPool2d(kernel_size=3, stride=2),
31
+ )
32
+ if self.num_classes > 0:
33
+ self.classifier = nn.Sequential(
34
+ nn.Dropout(),
35
+ nn.Linear(256 * 6 * 6, 4096),
36
+ nn.ReLU(inplace=True),
37
+ nn.Dropout(),
38
+ nn.Linear(4096, 4096),
39
+ nn.ReLU(inplace=True),
40
+ nn.Linear(4096, num_classes),
41
+ )
42
+
43
+ def init_weights(self, pretrained=None):
44
+ if isinstance(pretrained, str):
45
+ logger = logging.getLogger()
46
+ from ..runner import load_checkpoint
47
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
48
+ elif pretrained is None:
49
+ # use default initializer
50
+ pass
51
+ else:
52
+ raise TypeError('pretrained must be a str or None')
53
+
54
+ def forward(self, x):
55
+
56
+ x = self.features(x)
57
+ if self.num_classes > 0:
58
+ x = x.view(x.size(0), 256 * 6 * 6)
59
+ x = self.classifier(x)
60
+
61
+ return x
custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from .activation import build_activation_layer
3
+ from .context_block import ContextBlock
4
+ from .conv import build_conv_layer
5
+ from .conv2d_adaptive_padding import Conv2dAdaptivePadding
6
+ from .conv_module import ConvModule
7
+ from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
8
+ from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
9
+ from .drop import Dropout, DropPath
10
+ from .generalized_attention import GeneralizedAttention
11
+ from .hsigmoid import HSigmoid
12
+ from .hswish import HSwish
13
+ from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
14
+ from .norm import build_norm_layer, is_norm
15
+ from .padding import build_padding_layer
16
+ from .plugin import build_plugin_layer
17
+ from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
18
+ PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
19
+ from .scale import Scale
20
+ from .swish import Swish
21
+ from .upsample import build_upsample_layer
22
+ from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
23
+ Linear, MaxPool2d, MaxPool3d)
24
+
25
+ __all__ = [
26
+ 'ConvModule', 'build_activation_layer', 'build_conv_layer',
27
+ 'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
28
+ 'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
29
+ 'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
30
+ 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
31
+ 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
32
+ 'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
33
+ 'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
34
+ 'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
35
+ ]
custom_mmpkg/custom_mmcv/cnn/bricks/activation.py ADDED
@@ -0,0 +1,92 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
7
+ from .registry import ACTIVATION_LAYERS
8
+
9
+ for module in [
10
+ nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
11
+ nn.Sigmoid, nn.Tanh
12
+ ]:
13
+ ACTIVATION_LAYERS.register_module(module=module)
14
+
15
+
16
+ @ACTIVATION_LAYERS.register_module(name='Clip')
17
+ @ACTIVATION_LAYERS.register_module()
18
+ class Clamp(nn.Module):
19
+ """Clamp activation layer.
20
+
21
+ This activation function is to clamp the feature map value within
22
+ :math:`[min, max]`. More details can be found in ``torch.clamp()``.
23
+
24
+ Args:
25
+ min (Number | optional): Lower-bound of the range to be clamped to.
26
+ Default to -1.
27
+ max (Number | optional): Upper-bound of the range to be clamped to.
28
+ Default to 1.
29
+ """
30
+
31
+ def __init__(self, min=-1., max=1.):
32
+ super(Clamp, self).__init__()
33
+ self.min = min
34
+ self.max = max
35
+
36
+ def forward(self, x):
37
+ """Forward function.
38
+
39
+ Args:
40
+ x (torch.Tensor): The input tensor.
41
+
42
+ Returns:
43
+ torch.Tensor: Clamped tensor.
44
+ """
45
+ return torch.clamp(x, min=self.min, max=self.max)
46
+
47
+
48
+ class GELU(nn.Module):
49
+ r"""Applies the Gaussian Error Linear Units function:
50
+
51
+ .. math::
52
+ \text{GELU}(x) = x * \Phi(x)
53
+ where :math:`\Phi(x)` is the Cumulative Distribution Function for
54
+ Gaussian Distribution.
55
+
56
+ Shape:
57
+ - Input: :math:`(N, *)` where `*` means, any number of additional
58
+ dimensions
59
+ - Output: :math:`(N, *)`, same shape as the input
60
+
61
+ .. image:: scripts/activation_images/GELU.png
62
+
63
+ Examples::
64
+
65
+ >>> m = nn.GELU()
66
+ >>> input = torch.randn(2)
67
+ >>> output = m(input)
68
+ """
69
+
70
+ def forward(self, input):
71
+ return F.gelu(input)
72
+
73
+
74
+ if (TORCH_VERSION == 'parrots'
75
+ or digit_version(TORCH_VERSION) < digit_version('1.4')):
76
+ ACTIVATION_LAYERS.register_module(module=GELU)
77
+ else:
78
+ ACTIVATION_LAYERS.register_module(module=nn.GELU)
79
+
80
+
81
+ def build_activation_layer(cfg):
82
+ """Build activation layer.
83
+
84
+ Args:
85
+ cfg (dict): The activation layer config, which should contain:
86
+ - type (str): Layer type.
87
+ - layer args: Args needed to instantiate an activation layer.
88
+
89
+ Returns:
90
+ nn.Module: Created activation layer.
91
+ """
92
+ return build_from_cfg(cfg, ACTIVATION_LAYERS)
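A brief sketch of building activations from config dicts using the registry entries defined above; the 'Clip' name is the alias registered for the Clamp layer, and the import path assumes the package layout shown in this upload.

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks import build_activation_layer

    relu = build_activation_layer(dict(type='LeakyReLU', negative_slope=0.1))
    clip = build_activation_layer(dict(type='Clip', min=0., max=6.))
    y = clip(relu(torch.randn(2, 3)))  # LeakyReLU followed by clamping to [0, 6]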
custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py ADDED
@@ -0,0 +1,125 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ from torch import nn
4
+
5
+ from ..utils import constant_init, kaiming_init
6
+ from .registry import PLUGIN_LAYERS
7
+
8
+
9
+ def last_zero_init(m):
10
+ if isinstance(m, nn.Sequential):
11
+ constant_init(m[-1], val=0)
12
+ else:
13
+ constant_init(m, val=0)
14
+
15
+
16
+ @PLUGIN_LAYERS.register_module()
17
+ class ContextBlock(nn.Module):
18
+ """ContextBlock module in GCNet.
19
+
20
+ See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
21
+ (https://arxiv.org/abs/1904.11492) for details.
22
+
23
+ Args:
24
+ in_channels (int): Channels of the input feature map.
25
+ ratio (float): Ratio of channels of transform bottleneck
26
+ pooling_type (str): Pooling method for context modeling.
27
+ Options are 'att' and 'avg', stand for attention pooling and
28
+ average pooling respectively. Default: 'att'.
29
+ fusion_types (Sequence[str]): Fusion method for feature fusion,
30
+ Options are 'channel_add' and 'channel_mul', standing for channel-wise
31
+ addition and multiplication respectively. Default: ('channel_add',)
32
+ """
33
+
34
+ _abbr_ = 'context_block'
35
+
36
+ def __init__(self,
37
+ in_channels,
38
+ ratio,
39
+ pooling_type='att',
40
+ fusion_types=('channel_add', )):
41
+ super(ContextBlock, self).__init__()
42
+ assert pooling_type in ['avg', 'att']
43
+ assert isinstance(fusion_types, (list, tuple))
44
+ valid_fusion_types = ['channel_add', 'channel_mul']
45
+ assert all([f in valid_fusion_types for f in fusion_types])
46
+ assert len(fusion_types) > 0, 'at least one fusion should be used'
47
+ self.in_channels = in_channels
48
+ self.ratio = ratio
49
+ self.planes = int(in_channels * ratio)
50
+ self.pooling_type = pooling_type
51
+ self.fusion_types = fusion_types
52
+ if pooling_type == 'att':
53
+ self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
54
+ self.softmax = nn.Softmax(dim=2)
55
+ else:
56
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
57
+ if 'channel_add' in fusion_types:
58
+ self.channel_add_conv = nn.Sequential(
59
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
60
+ nn.LayerNorm([self.planes, 1, 1]),
61
+ nn.ReLU(inplace=True), # yapf: disable
62
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
63
+ else:
64
+ self.channel_add_conv = None
65
+ if 'channel_mul' in fusion_types:
66
+ self.channel_mul_conv = nn.Sequential(
67
+ nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
68
+ nn.LayerNorm([self.planes, 1, 1]),
69
+ nn.ReLU(inplace=True), # yapf: disable
70
+ nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
71
+ else:
72
+ self.channel_mul_conv = None
73
+ self.reset_parameters()
74
+
75
+ def reset_parameters(self):
76
+ if self.pooling_type == 'att':
77
+ kaiming_init(self.conv_mask, mode='fan_in')
78
+ self.conv_mask.inited = True
79
+
80
+ if self.channel_add_conv is not None:
81
+ last_zero_init(self.channel_add_conv)
82
+ if self.channel_mul_conv is not None:
83
+ last_zero_init(self.channel_mul_conv)
84
+
85
+ def spatial_pool(self, x):
86
+ batch, channel, height, width = x.size()
87
+ if self.pooling_type == 'att':
88
+ input_x = x
89
+ # [N, C, H * W]
90
+ input_x = input_x.view(batch, channel, height * width)
91
+ # [N, 1, C, H * W]
92
+ input_x = input_x.unsqueeze(1)
93
+ # [N, 1, H, W]
94
+ context_mask = self.conv_mask(x)
95
+ # [N, 1, H * W]
96
+ context_mask = context_mask.view(batch, 1, height * width)
97
+ # [N, 1, H * W]
98
+ context_mask = self.softmax(context_mask)
99
+ # [N, 1, H * W, 1]
100
+ context_mask = context_mask.unsqueeze(-1)
101
+ # [N, 1, C, 1]
102
+ context = torch.matmul(input_x, context_mask)
103
+ # [N, C, 1, 1]
104
+ context = context.view(batch, channel, 1, 1)
105
+ else:
106
+ # [N, C, 1, 1]
107
+ context = self.avg_pool(x)
108
+
109
+ return context
110
+
111
+ def forward(self, x):
112
+ # [N, C, 1, 1]
113
+ context = self.spatial_pool(x)
114
+
115
+ out = x
116
+ if self.channel_mul_conv is not None:
117
+ # [N, C, 1, 1]
118
+ channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
119
+ out = out * channel_mul_term
120
+ if self.channel_add_conv is not None:
121
+ # [N, C, 1, 1]
122
+ channel_add_term = self.channel_add_conv(context)
123
+ out = out + channel_add_term
124
+
125
+ return out
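A minimal sketch of plugging the block into a feature map (shapes and import path are illustrative assumptions based on the cnn __init__ in this upload):

    import torch
    from custom_mmpkg.custom_mmcv.cnn import ContextBlock

    block = ContextBlock(in_channels=64, ratio=0.25)  # attention pooling + 'channel_add' fusion by default
    x = torch.randn(2, 64, 32, 32)
    out = block(x)  # same shape as x: the pooled global context is added back at every position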
custom_mmpkg/custom_mmcv/cnn/bricks/conv.py ADDED
@@ -0,0 +1,44 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from torch import nn
3
+
4
+ from .registry import CONV_LAYERS
5
+
6
+ CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
7
+ CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
8
+ CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
9
+ CONV_LAYERS.register_module('Conv', module=nn.Conv2d)
10
+
11
+
12
+ def build_conv_layer(cfg, *args, **kwargs):
13
+ """Build convolution layer.
14
+
15
+ Args:
16
+ cfg (None or dict): The conv layer config, which should contain:
17
+ - type (str): Layer type.
18
+ - layer args: Args needed to instantiate an conv layer.
19
+ args (argument list): Arguments passed to the `__init__`
20
+ method of the corresponding conv layer.
21
+ kwargs (keyword arguments): Keyword arguments passed to the `__init__`
22
+ method of the corresponding conv layer.
23
+
24
+ Returns:
25
+ nn.Module: Created conv layer.
26
+ """
27
+ if cfg is None:
28
+ cfg_ = dict(type='Conv2d')
29
+ else:
30
+ if not isinstance(cfg, dict):
31
+ raise TypeError('cfg must be a dict')
32
+ if 'type' not in cfg:
33
+ raise KeyError('the cfg dict must contain the key "type"')
34
+ cfg_ = cfg.copy()
35
+
36
+ layer_type = cfg_.pop('type')
37
+ if layer_type not in CONV_LAYERS:
38
+ raise KeyError(f'Unrecognized conv layer type {layer_type}')
39
+ else:
40
+ conv_layer = CONV_LAYERS.get(layer_type)
41
+
42
+ layer = conv_layer(*args, **kwargs, **cfg_)
43
+
44
+ return layer
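A short sketch of the two ways this builder is typically called, a cfg dict and the cfg=None fallback (import path assumed from the package layout above):

    from custom_mmpkg.custom_mmcv.cnn import build_conv_layer

    conv = build_conv_layer(dict(type='Conv2d'), 3, 16, kernel_size=3, padding=1)
    # cfg=None falls back to a plain nn.Conv2d built with the same args
    conv_default = build_conv_layer(None, 3, 16, kernel_size=3, padding=1)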
custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py ADDED
@@ -0,0 +1,62 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ from .registry import CONV_LAYERS
8
+
9
+
10
+ @CONV_LAYERS.register_module()
11
+ class Conv2dAdaptivePadding(nn.Conv2d):
12
+ """Implementation of 2D convolution in tensorflow with `padding` as "same",
13
+ which applies padding to input (if needed) so that input image gets fully
14
+ covered by filter and stride you specified. For stride 1, this will ensure
15
+ that output image size is same as input. For stride of 2, output dimensions
16
+ will be half, for example.
17
+
18
+ Args:
19
+ in_channels (int): Number of channels in the input image
20
+ out_channels (int): Number of channels produced by the convolution
21
+ kernel_size (int or tuple): Size of the convolving kernel
22
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
23
+ padding (int or tuple, optional): Zero-padding added to both sides of
24
+ the input. Default: 0
25
+ dilation (int or tuple, optional): Spacing between kernel elements.
26
+ Default: 1
27
+ groups (int, optional): Number of blocked connections from input
28
+ channels to output channels. Default: 1
29
+ bias (bool, optional): If ``True``, adds a learnable bias to the
30
+ output. Default: ``True``
31
+ """
32
+
33
+ def __init__(self,
34
+ in_channels,
35
+ out_channels,
36
+ kernel_size,
37
+ stride=1,
38
+ padding=0,
39
+ dilation=1,
40
+ groups=1,
41
+ bias=True):
42
+ super().__init__(in_channels, out_channels, kernel_size, stride, 0,
43
+ dilation, groups, bias)
44
+
45
+ def forward(self, x):
46
+ img_h, img_w = x.size()[-2:]
47
+ kernel_h, kernel_w = self.weight.size()[-2:]
48
+ stride_h, stride_w = self.stride
49
+ output_h = math.ceil(img_h / stride_h)
50
+ output_w = math.ceil(img_w / stride_w)
51
+ pad_h = (
52
+ max((output_h - 1) * self.stride[0] +
53
+ (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
54
+ pad_w = (
55
+ max((output_w - 1) * self.stride[1] +
56
+ (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
57
+ if pad_h > 0 or pad_w > 0:
58
+ x = F.pad(x, [
59
+ pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
60
+ ])
61
+ return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
62
+ self.dilation, self.groups)
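A small sketch of the "same"-style padding behaviour (the numbers are only illustrative):

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks import Conv2dAdaptivePadding

    conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
    x = torch.randn(1, 3, 15, 15)
    out = conv(x)  # padded on the fly so out.shape[-2:] == (ceil(15 / 2), ceil(15 / 2)) == (8, 8)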
custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py ADDED
@@ -0,0 +1,206 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import warnings
3
+
4
+ import torch.nn as nn
5
+
6
+ from custom_mmpkg.custom_mmcv.utils import _BatchNorm, _InstanceNorm
7
+ from ..utils import constant_init, kaiming_init
8
+ from .activation import build_activation_layer
9
+ from .conv import build_conv_layer
10
+ from .norm import build_norm_layer
11
+ from .padding import build_padding_layer
12
+ from .registry import PLUGIN_LAYERS
13
+
14
+
15
+ @PLUGIN_LAYERS.register_module()
16
+ class ConvModule(nn.Module):
17
+ """A conv block that bundles conv/norm/activation layers.
18
+
19
+ This block simplifies the usage of convolution layers, which are commonly
20
+ used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
21
+ It is based upon three build methods: `build_conv_layer()`,
22
+ `build_norm_layer()` and `build_activation_layer()`.
23
+
24
+ Besides, we add some additional features in this module.
25
+ 1. Automatically set `bias` of the conv layer.
26
+ 2. Spectral norm is supported.
27
+ 3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
28
+ supports zero and circular padding, and we add "reflect" padding mode.
29
+
30
+ Args:
31
+ in_channels (int): Number of channels in the input feature map.
32
+ Same as that in ``nn._ConvNd``.
33
+ out_channels (int): Number of channels produced by the convolution.
34
+ Same as that in ``nn._ConvNd``.
35
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
36
+ Same as that in ``nn._ConvNd``.
37
+ stride (int | tuple[int]): Stride of the convolution.
38
+ Same as that in ``nn._ConvNd``.
39
+ padding (int | tuple[int]): Zero-padding added to both sides of
40
+ the input. Same as that in ``nn._ConvNd``.
41
+ dilation (int | tuple[int]): Spacing between kernel elements.
42
+ Same as that in ``nn._ConvNd``.
43
+ groups (int): Number of blocked connections from input channels to
44
+ output channels. Same as that in ``nn._ConvNd``.
45
+ bias (bool | str): If specified as `auto`, it will be decided by the
46
+ norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
47
+ False. Default: "auto".
48
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
49
+ which means using conv2d.
50
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
51
+ act_cfg (dict): Config dict for activation layer.
52
+ Default: dict(type='ReLU').
53
+ inplace (bool): Whether to use inplace mode for activation.
54
+ Default: True.
55
+ with_spectral_norm (bool): Whether use spectral norm in conv module.
56
+ Default: False.
57
+ padding_mode (str): If the `padding_mode` has not been supported by
58
+ current `Conv2d` in PyTorch, we will use our own padding layer
59
+ instead. Currently, we support ['zeros', 'circular'] with official
60
+ implementation and ['reflect'] with our own implementation.
61
+ Default: 'zeros'.
62
+ order (tuple[str]): The order of conv/norm/activation layers. It is a
63
+ sequence of "conv", "norm" and "act". Common examples are
64
+ ("conv", "norm", "act") and ("act", "conv", "norm").
65
+ Default: ('conv', 'norm', 'act').
66
+ """
67
+
68
+ _abbr_ = 'conv_block'
69
+
70
+ def __init__(self,
71
+ in_channels,
72
+ out_channels,
73
+ kernel_size,
74
+ stride=1,
75
+ padding=0,
76
+ dilation=1,
77
+ groups=1,
78
+ bias='auto',
79
+ conv_cfg=None,
80
+ norm_cfg=None,
81
+ act_cfg=dict(type='ReLU'),
82
+ inplace=True,
83
+ with_spectral_norm=False,
84
+ padding_mode='zeros',
85
+ order=('conv', 'norm', 'act')):
86
+ super(ConvModule, self).__init__()
87
+ assert conv_cfg is None or isinstance(conv_cfg, dict)
88
+ assert norm_cfg is None or isinstance(norm_cfg, dict)
89
+ assert act_cfg is None or isinstance(act_cfg, dict)
90
+ official_padding_mode = ['zeros', 'circular']
91
+ self.conv_cfg = conv_cfg
92
+ self.norm_cfg = norm_cfg
93
+ self.act_cfg = act_cfg
94
+ self.inplace = inplace
95
+ self.with_spectral_norm = with_spectral_norm
96
+ self.with_explicit_padding = padding_mode not in official_padding_mode
97
+ self.order = order
98
+ assert isinstance(self.order, tuple) and len(self.order) == 3
99
+ assert set(order) == set(['conv', 'norm', 'act'])
100
+
101
+ self.with_norm = norm_cfg is not None
102
+ self.with_activation = act_cfg is not None
103
+ # if the conv layer is before a norm layer, bias is unnecessary.
104
+ if bias == 'auto':
105
+ bias = not self.with_norm
106
+ self.with_bias = bias
107
+
108
+ if self.with_explicit_padding:
109
+ pad_cfg = dict(type=padding_mode)
110
+ self.padding_layer = build_padding_layer(pad_cfg, padding)
111
+
112
+ # reset padding to 0 for conv module
113
+ conv_padding = 0 if self.with_explicit_padding else padding
114
+ # build convolution layer
115
+ self.conv = build_conv_layer(
116
+ conv_cfg,
117
+ in_channels,
118
+ out_channels,
119
+ kernel_size,
120
+ stride=stride,
121
+ padding=conv_padding,
122
+ dilation=dilation,
123
+ groups=groups,
124
+ bias=bias)
125
+ # export the attributes of self.conv to a higher level for convenience
126
+ self.in_channels = self.conv.in_channels
127
+ self.out_channels = self.conv.out_channels
128
+ self.kernel_size = self.conv.kernel_size
129
+ self.stride = self.conv.stride
130
+ self.padding = padding
131
+ self.dilation = self.conv.dilation
132
+ self.transposed = self.conv.transposed
133
+ self.output_padding = self.conv.output_padding
134
+ self.groups = self.conv.groups
135
+
136
+ if self.with_spectral_norm:
137
+ self.conv = nn.utils.spectral_norm(self.conv)
138
+
139
+ # build normalization layers
140
+ if self.with_norm:
141
+ # norm layer is after conv layer
142
+ if order.index('norm') > order.index('conv'):
143
+ norm_channels = out_channels
144
+ else:
145
+ norm_channels = in_channels
146
+ self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
147
+ self.add_module(self.norm_name, norm)
148
+ if self.with_bias:
149
+ if isinstance(norm, (_BatchNorm, _InstanceNorm)):
150
+ warnings.warn(
151
+ 'Unnecessary conv bias before batch/instance norm')
152
+ else:
153
+ self.norm_name = None
154
+
155
+ # build activation layer
156
+ if self.with_activation:
157
+ act_cfg_ = act_cfg.copy()
158
+ # nn.Tanh has no 'inplace' argument
159
+ if act_cfg_['type'] not in [
160
+ 'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
161
+ ]:
162
+ act_cfg_.setdefault('inplace', inplace)
163
+ self.activate = build_activation_layer(act_cfg_)
164
+
165
+ # Use msra init by default
166
+ self.init_weights()
167
+
168
+ @property
169
+ def norm(self):
170
+ if self.norm_name:
171
+ return getattr(self, self.norm_name)
172
+ else:
173
+ return None
174
+
175
+ def init_weights(self):
176
+ # 1. It is mainly for customized conv layers with their own
177
+ # initialization manners by calling their own ``init_weights()``,
178
+ # and we do not want ConvModule to override the initialization.
179
+ # 2. For customized conv layers without their own initialization
180
+ # manners (that is, they don't have their own ``init_weights()``)
181
+ # and PyTorch's conv layers, they will be initialized by
182
+ # this method with default ``kaiming_init``.
183
+ # Note: For PyTorch's conv layers, they will be overwritten by our
184
+ # initialization implementation using default ``kaiming_init``.
185
+ if not hasattr(self.conv, 'init_weights'):
186
+ if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
187
+ nonlinearity = 'leaky_relu'
188
+ a = self.act_cfg.get('negative_slope', 0.01)
189
+ else:
190
+ nonlinearity = 'relu'
191
+ a = 0
192
+ kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
193
+ if self.with_norm:
194
+ constant_init(self.norm, 1, bias=0)
195
+
196
+ def forward(self, x, activate=True, norm=True):
197
+ for layer in self.order:
198
+ if layer == 'conv':
199
+ if self.with_explicit_padding:
200
+ x = self.padding_layer(x)
201
+ x = self.conv(x)
202
+ elif layer == 'norm' and norm and self.with_norm:
203
+ x = self.norm(x)
204
+ elif layer == 'act' and activate and self.with_activation:
205
+ x = self.activate(x)
206
+ return x
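A minimal usage sketch; the 'BN' norm type is assumed to be registered in NORM_LAYERS by the norm.py module elsewhere in this upload, and the import path follows the cnn __init__ above.

    import torch
    from custom_mmpkg.custom_mmcv.cnn import ConvModule

    block = ConvModule(3, 16, 3, padding=1, norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
    x = torch.randn(1, 3, 32, 32)
    out = block(x)  # conv -> BN -> ReLU; the conv bias is dropped automatically because a norm follows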
custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py ADDED
@@ -0,0 +1,148 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from .registry import CONV_LAYERS
7
+
8
+
9
+ def conv_ws_2d(input,
10
+ weight,
11
+ bias=None,
12
+ stride=1,
13
+ padding=0,
14
+ dilation=1,
15
+ groups=1,
16
+ eps=1e-5):
17
+ c_in = weight.size(0)
18
+ weight_flat = weight.view(c_in, -1)
19
+ mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
20
+ std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
21
+ weight = (weight - mean) / (std + eps)
22
+ return F.conv2d(input, weight, bias, stride, padding, dilation, groups)
23
+
24
+
25
+ @CONV_LAYERS.register_module('ConvWS')
26
+ class ConvWS2d(nn.Conv2d):
27
+
28
+ def __init__(self,
29
+ in_channels,
30
+ out_channels,
31
+ kernel_size,
32
+ stride=1,
33
+ padding=0,
34
+ dilation=1,
35
+ groups=1,
36
+ bias=True,
37
+ eps=1e-5):
38
+ super(ConvWS2d, self).__init__(
39
+ in_channels,
40
+ out_channels,
41
+ kernel_size,
42
+ stride=stride,
43
+ padding=padding,
44
+ dilation=dilation,
45
+ groups=groups,
46
+ bias=bias)
47
+ self.eps = eps
48
+
49
+ def forward(self, x):
50
+ return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
51
+ self.dilation, self.groups, self.eps)
52
+
53
+
54
+ @CONV_LAYERS.register_module(name='ConvAWS')
55
+ class ConvAWS2d(nn.Conv2d):
56
+ """AWS (Adaptive Weight Standardization)
57
+
58
+ This is a variant of Weight Standardization
59
+ (https://arxiv.org/pdf/1903.10520.pdf)
60
+ It is used in DetectoRS to avoid NaN
61
+ (https://arxiv.org/pdf/2006.02334.pdf)
62
+
63
+ Args:
64
+ in_channels (int): Number of channels in the input image
65
+ out_channels (int): Number of channels produced by the convolution
66
+ kernel_size (int or tuple): Size of the conv kernel
67
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
68
+ padding (int or tuple, optional): Zero-padding added to both sides of
69
+ the input. Default: 0
70
+ dilation (int or tuple, optional): Spacing between kernel elements.
71
+ Default: 1
72
+ groups (int, optional): Number of blocked connections from input
73
+ channels to output channels. Default: 1
74
+ bias (bool, optional): If set True, adds a learnable bias to the
75
+ output. Default: True
76
+ """
77
+
78
+ def __init__(self,
79
+ in_channels,
80
+ out_channels,
81
+ kernel_size,
82
+ stride=1,
83
+ padding=0,
84
+ dilation=1,
85
+ groups=1,
86
+ bias=True):
87
+ super().__init__(
88
+ in_channels,
89
+ out_channels,
90
+ kernel_size,
91
+ stride=stride,
92
+ padding=padding,
93
+ dilation=dilation,
94
+ groups=groups,
95
+ bias=bias)
96
+ self.register_buffer('weight_gamma',
97
+ torch.ones(self.out_channels, 1, 1, 1))
98
+ self.register_buffer('weight_beta',
99
+ torch.zeros(self.out_channels, 1, 1, 1))
100
+
101
+ def _get_weight(self, weight):
102
+ weight_flat = weight.view(weight.size(0), -1)
103
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
104
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
105
+ weight = (weight - mean) / std
106
+ weight = self.weight_gamma * weight + self.weight_beta
107
+ return weight
108
+
109
+ def forward(self, x):
110
+ weight = self._get_weight(self.weight)
111
+ return F.conv2d(x, weight, self.bias, self.stride, self.padding,
112
+ self.dilation, self.groups)
113
+
114
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
115
+ missing_keys, unexpected_keys, error_msgs):
116
+ """Override default load function.
117
+
118
+ AWS overrides the function _load_from_state_dict to recover
119
+ weight_gamma and weight_beta if they are missing. If weight_gamma and
120
+ weight_beta are found in the checkpoint, this function will return
121
+ after super()._load_from_state_dict. Otherwise, it will compute the
122
+ mean and std of the pretrained weights and store them in weight_beta
123
+ and weight_gamma.
124
+ """
125
+
126
+ self.weight_gamma.data.fill_(-1)
127
+ local_missing_keys = []
128
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
129
+ strict, local_missing_keys,
130
+ unexpected_keys, error_msgs)
131
+ if self.weight_gamma.data.mean() > 0:
132
+ for k in local_missing_keys:
133
+ missing_keys.append(k)
134
+ return
135
+ weight = self.weight.data
136
+ weight_flat = weight.view(weight.size(0), -1)
137
+ mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
138
+ std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
139
+ self.weight_beta.data.copy_(mean)
140
+ self.weight_gamma.data.copy_(std)
141
+ missing_gamma_beta = [
142
+ k for k in local_missing_keys
143
+ if k.endswith('weight_gamma') or k.endswith('weight_beta')
144
+ ]
145
+ for k in missing_gamma_beta:
146
+ local_missing_keys.remove(k)
147
+ for k in local_missing_keys:
148
+ missing_keys.append(k)
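A drop-in usage sketch for the weight-standardized variant (import path assumed from the package layout above):

    import torch
    from custom_mmpkg.custom_mmcv.cnn import ConvWS2d

    conv = ConvWS2d(3, 8, kernel_size=3, padding=1)
    x = torch.randn(2, 3, 16, 16)
    out = conv(x)  # each filter is normalized to zero mean / unit std before the convolution is applied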
custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+
4
+ from .conv_module import ConvModule
5
+
6
+
7
+ class DepthwiseSeparableConvModule(nn.Module):
8
+ """Depthwise separable convolution module.
9
+
10
+ See https://arxiv.org/pdf/1704.04861.pdf for details.
11
+
12
+ This module can replace a ConvModule with the conv block replaced by two
13
+ conv block: depthwise conv block and pointwise conv block. The depthwise
14
+ conv block contains depthwise-conv/norm/activation layers. The pointwise
15
+ conv block contains pointwise-conv/norm/activation layers. It should be
16
+ noted that there will be norm/activation layer in the depthwise conv block
17
+ if `norm_cfg` and `act_cfg` are specified.
18
+
19
+ Args:
20
+ in_channels (int): Number of channels in the input feature map.
21
+ Same as that in ``nn._ConvNd``.
22
+ out_channels (int): Number of channels produced by the convolution.
23
+ Same as that in ``nn._ConvNd``.
24
+ kernel_size (int | tuple[int]): Size of the convolving kernel.
25
+ Same as that in ``nn._ConvNd``.
26
+ stride (int | tuple[int]): Stride of the convolution.
27
+ Same as that in ``nn._ConvNd``. Default: 1.
28
+ padding (int | tuple[int]): Zero-padding added to both sides of
29
+ the input. Same as that in ``nn._ConvNd``. Default: 0.
30
+ dilation (int | tuple[int]): Spacing between kernel elements.
31
+ Same as that in ``nn._ConvNd``. Default: 1.
32
+ norm_cfg (dict): Default norm config for both depthwise ConvModule and
33
+ pointwise ConvModule. Default: None.
34
+ act_cfg (dict): Default activation config for both depthwise ConvModule
35
+ and pointwise ConvModule. Default: dict(type='ReLU').
36
+ dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
37
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
38
+ dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
39
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
40
+ pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
41
+ 'default', it will be the same as `norm_cfg`. Default: 'default'.
42
+ pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
43
+ 'default', it will be the same as `act_cfg`. Default: 'default'.
44
+ kwargs (optional): Other shared arguments for depthwise and pointwise
45
+ ConvModule. See ConvModule for ref.
46
+ """
47
+
48
+ def __init__(self,
49
+ in_channels,
50
+ out_channels,
51
+ kernel_size,
52
+ stride=1,
53
+ padding=0,
54
+ dilation=1,
55
+ norm_cfg=None,
56
+ act_cfg=dict(type='ReLU'),
57
+ dw_norm_cfg='default',
58
+ dw_act_cfg='default',
59
+ pw_norm_cfg='default',
60
+ pw_act_cfg='default',
61
+ **kwargs):
62
+ super(DepthwiseSeparableConvModule, self).__init__()
63
+ assert 'groups' not in kwargs, 'groups should not be specified'
64
+
65
+ # if norm/activation config of depthwise/pointwise ConvModule is not
66
+ # specified, use default config.
67
+ dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
68
+ dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
69
+ pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
70
+ pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
71
+
72
+ # depthwise convolution
73
+ self.depthwise_conv = ConvModule(
74
+ in_channels,
75
+ in_channels,
76
+ kernel_size,
77
+ stride=stride,
78
+ padding=padding,
79
+ dilation=dilation,
80
+ groups=in_channels,
81
+ norm_cfg=dw_norm_cfg,
82
+ act_cfg=dw_act_cfg,
83
+ **kwargs)
84
+
85
+ self.pointwise_conv = ConvModule(
86
+ in_channels,
87
+ out_channels,
88
+ 1,
89
+ norm_cfg=pw_norm_cfg,
90
+ act_cfg=pw_act_cfg,
91
+ **kwargs)
92
+
93
+ def forward(self, x):
94
+ x = self.depthwise_conv(x)
95
+ x = self.pointwise_conv(x)
96
+ return x
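A minimal sketch of the depthwise + pointwise factorization (shapes are illustrative):

    import torch
    from custom_mmpkg.custom_mmcv.cnn import DepthwiseSeparableConvModule

    dsconv = DepthwiseSeparableConvModule(32, 64, 3, padding=1)
    x = torch.randn(1, 32, 28, 28)
    out = dsconv(x)  # 3x3 depthwise conv (groups=32) then a 1x1 pointwise conv -> (1, 64, 28, 28)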
custom_mmpkg/custom_mmcv/cnn/bricks/drop.py ADDED
@@ -0,0 +1,65 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from custom_mmpkg.custom_mmcv import build_from_cfg
6
+ from .registry import DROPOUT_LAYERS
7
+
8
+
9
+ def drop_path(x, drop_prob=0., training=False):
10
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
11
+ residual blocks).
12
+
13
+ We follow the implementation
14
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
15
+ """
16
+ if drop_prob == 0. or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ # handle tensors with different dimensions, not just 4D tensors.
20
+ shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
21
+ random_tensor = keep_prob + torch.rand(
22
+ shape, dtype=x.dtype, device=x.device)
23
+ output = x.div(keep_prob) * random_tensor.floor()
24
+ return output
25
+
26
+
27
+ @DROPOUT_LAYERS.register_module()
28
+ class DropPath(nn.Module):
29
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
30
+ residual blocks).
31
+
32
+ We follow the implementation
33
+ https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
34
+
35
+ Args:
36
+ drop_prob (float): Probability of the path to be zeroed. Default: 0.1
37
+ """
38
+
39
+ def __init__(self, drop_prob=0.1):
40
+ super(DropPath, self).__init__()
41
+ self.drop_prob = drop_prob
42
+
43
+ def forward(self, x):
44
+ return drop_path(x, self.drop_prob, self.training)
45
+
46
+
47
+ @DROPOUT_LAYERS.register_module()
48
+ class Dropout(nn.Dropout):
49
+ """A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
50
+ ``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
51
+ ``DropPath``
52
+
53
+ Args:
54
+ drop_prob (float): Probability of the elements to be
55
+ zeroed. Default: 0.5.
56
+ inplace (bool): Do the operation inplace or not. Default: False.
57
+ """
58
+
59
+ def __init__(self, drop_prob=0.5, inplace=False):
60
+ super().__init__(p=drop_prob, inplace=inplace)
61
+
62
+
63
+ def build_dropout(cfg, default_args=None):
64
+ """Builder for drop out layers."""
65
+ return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
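A short sketch of the stochastic-depth behaviour in train vs. eval mode (import path assumed from the bricks __init__ above):

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.randn(4, 64, 8, 8)
    out = dp(x)                   # training mode: ~20% of the samples are zeroed, survivors rescaled by 1 / 0.8
    dp.eval()
    assert torch.equal(dp(x), x)  # inference: identity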
custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py ADDED
@@ -0,0 +1,412 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import math
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from ..utils import kaiming_init
10
+ from .registry import PLUGIN_LAYERS
11
+
12
+
13
+ @PLUGIN_LAYERS.register_module()
14
+ class GeneralizedAttention(nn.Module):
15
+ """GeneralizedAttention module.
16
+
17
+ See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
18
+ (https://arxiv.org/abs/1904.05873) for details.
19
+
20
+ Args:
21
+ in_channels (int): Channels of the input feature map.
22
+ spatial_range (int): The spatial range. -1 indicates no spatial range
23
+ constraint. Default: -1.
24
+ num_heads (int): The head number of empirical_attention module.
25
+ Default: 9.
26
+ position_embedding_dim (int): The position embedding dimension.
27
+ Default: -1.
28
+ position_magnitude (int): A multiplier acting on coord difference.
29
+ Default: 1.
30
+ kv_stride (int): The feature stride acting on key/value feature map.
31
+ Default: 2.
32
+ q_stride (int): The feature stride acting on query feature map.
33
+ Default: 1.
34
+ attention_type (str): A binary indicator string for indicating which
35
+ items in generalized empirical_attention module are used.
36
+ Default: '1111'.
37
+
38
+ - '1000' indicates 'query and key content' (appr - appr) item,
39
+ - '0100' indicates 'query content and relative position'
40
+ (appr - position) item,
41
+ - '0010' indicates 'key content only' (bias - appr) item,
42
+ - '0001' indicates 'relative position only' (bias - position) item.
43
+ """
44
+
45
+ _abbr_ = 'gen_attention_block'
46
+
47
+ def __init__(self,
48
+ in_channels,
49
+ spatial_range=-1,
50
+ num_heads=9,
51
+ position_embedding_dim=-1,
52
+ position_magnitude=1,
53
+ kv_stride=2,
54
+ q_stride=1,
55
+ attention_type='1111'):
56
+
57
+ super(GeneralizedAttention, self).__init__()
58
+
59
+ # hard range means local range for non-local operation
60
+ self.position_embedding_dim = (
61
+ position_embedding_dim
62
+ if position_embedding_dim > 0 else in_channels)
63
+
64
+ self.position_magnitude = position_magnitude
65
+ self.num_heads = num_heads
66
+ self.in_channels = in_channels
67
+ self.spatial_range = spatial_range
68
+ self.kv_stride = kv_stride
69
+ self.q_stride = q_stride
70
+ self.attention_type = [bool(int(_)) for _ in attention_type]
71
+ self.qk_embed_dim = in_channels // num_heads
72
+ out_c = self.qk_embed_dim * num_heads
73
+
74
+ if self.attention_type[0] or self.attention_type[1]:
75
+ self.query_conv = nn.Conv2d(
76
+ in_channels=in_channels,
77
+ out_channels=out_c,
78
+ kernel_size=1,
79
+ bias=False)
80
+ self.query_conv.kaiming_init = True
81
+
82
+ if self.attention_type[0] or self.attention_type[2]:
83
+ self.key_conv = nn.Conv2d(
84
+ in_channels=in_channels,
85
+ out_channels=out_c,
86
+ kernel_size=1,
87
+ bias=False)
88
+ self.key_conv.kaiming_init = True
89
+
90
+ self.v_dim = in_channels // num_heads
91
+ self.value_conv = nn.Conv2d(
92
+ in_channels=in_channels,
93
+ out_channels=self.v_dim * num_heads,
94
+ kernel_size=1,
95
+ bias=False)
96
+ self.value_conv.kaiming_init = True
97
+
98
+ if self.attention_type[1] or self.attention_type[3]:
99
+ self.appr_geom_fc_x = nn.Linear(
100
+ self.position_embedding_dim // 2, out_c, bias=False)
101
+ self.appr_geom_fc_x.kaiming_init = True
102
+
103
+ self.appr_geom_fc_y = nn.Linear(
104
+ self.position_embedding_dim // 2, out_c, bias=False)
105
+ self.appr_geom_fc_y.kaiming_init = True
106
+
107
+ if self.attention_type[2]:
108
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
109
+ appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
110
+ self.appr_bias = nn.Parameter(appr_bias_value)
111
+
112
+ if self.attention_type[3]:
113
+ stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
114
+ geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
115
+ self.geom_bias = nn.Parameter(geom_bias_value)
116
+
117
+ self.proj_conv = nn.Conv2d(
118
+ in_channels=self.v_dim * num_heads,
119
+ out_channels=in_channels,
120
+ kernel_size=1,
121
+ bias=True)
122
+ self.proj_conv.kaiming_init = True
123
+ self.gamma = nn.Parameter(torch.zeros(1))
124
+
125
+ if self.spatial_range >= 0:
126
+ # only works when non local is after 3*3 conv
127
+ if in_channels == 256:
128
+ max_len = 84
129
+ elif in_channels == 512:
130
+ max_len = 42
131
+
132
+ max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
133
+ local_constraint_map = np.ones(
134
+ (max_len, max_len, max_len_kv, max_len_kv), dtype=int)
135
+ for iy in range(max_len):
136
+ for ix in range(max_len):
137
+ local_constraint_map[
138
+ iy, ix,
139
+ max((iy - self.spatial_range) //
140
+ self.kv_stride, 0):min((iy + self.spatial_range +
141
+ 1) // self.kv_stride +
142
+ 1, max_len),
143
+ max((ix - self.spatial_range) //
144
+ self.kv_stride, 0):min((ix + self.spatial_range +
145
+ 1) // self.kv_stride +
146
+ 1, max_len)] = 0
147
+
148
+ self.local_constraint_map = nn.Parameter(
149
+ torch.from_numpy(local_constraint_map).byte(),
150
+ requires_grad=False)
151
+
152
+ if self.q_stride > 1:
153
+ self.q_downsample = nn.AvgPool2d(
154
+ kernel_size=1, stride=self.q_stride)
155
+ else:
156
+ self.q_downsample = None
157
+
158
+ if self.kv_stride > 1:
159
+ self.kv_downsample = nn.AvgPool2d(
160
+ kernel_size=1, stride=self.kv_stride)
161
+ else:
162
+ self.kv_downsample = None
163
+
164
+ self.init_weights()
165
+
166
+ def get_position_embedding(self,
167
+ h,
168
+ w,
169
+ h_kv,
170
+ w_kv,
171
+ q_stride,
172
+ kv_stride,
173
+ device,
174
+ dtype,
175
+ feat_dim,
176
+ wave_length=1000):
177
+ # the default type of Tensor is float32, leading to type mismatch
178
+ # in fp16 mode. Cast it to support fp16 mode.
179
+ h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
180
+ h_idxs = h_idxs.view((h, 1)) * q_stride
181
+
182
+ w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
183
+ w_idxs = w_idxs.view((w, 1)) * q_stride
184
+
185
+ h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
186
+ device=device, dtype=dtype)
187
+ h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
188
+
189
+ w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
190
+ device=device, dtype=dtype)
191
+ w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
192
+
193
+ # (h, h_kv, 1)
194
+ h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
195
+ h_diff *= self.position_magnitude
196
+
197
+ # (w, w_kv, 1)
198
+ w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
199
+ w_diff *= self.position_magnitude
200
+
201
+ feat_range = torch.arange(0, feat_dim / 4).to(
202
+ device=device, dtype=dtype)
203
+
204
+ dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
205
+ dim_mat = dim_mat**((4. / feat_dim) * feat_range)
206
+ dim_mat = dim_mat.view((1, 1, -1))
207
+
208
+ embedding_x = torch.cat(
209
+ ((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
210
+
211
+ embedding_y = torch.cat(
212
+ ((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
213
+
214
+ return embedding_x, embedding_y
215
+
216
+ def forward(self, x_input):
217
+ num_heads = self.num_heads
218
+
219
+ # use empirical_attention
220
+ if self.q_downsample is not None:
221
+ x_q = self.q_downsample(x_input)
222
+ else:
223
+ x_q = x_input
224
+ n, _, h, w = x_q.shape
225
+
226
+ if self.kv_downsample is not None:
227
+ x_kv = self.kv_downsample(x_input)
228
+ else:
229
+ x_kv = x_input
230
+ _, _, h_kv, w_kv = x_kv.shape
231
+
232
+ if self.attention_type[0] or self.attention_type[1]:
233
+ proj_query = self.query_conv(x_q).view(
234
+ (n, num_heads, self.qk_embed_dim, h * w))
235
+ proj_query = proj_query.permute(0, 1, 3, 2)
236
+
237
+ if self.attention_type[0] or self.attention_type[2]:
238
+ proj_key = self.key_conv(x_kv).view(
239
+ (n, num_heads, self.qk_embed_dim, h_kv * w_kv))
240
+
241
+ if self.attention_type[1] or self.attention_type[3]:
242
+ position_embed_x, position_embed_y = self.get_position_embedding(
243
+ h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
244
+ x_input.device, x_input.dtype, self.position_embedding_dim)
245
+ # (n, num_heads, w, w_kv, dim)
246
+ position_feat_x = self.appr_geom_fc_x(position_embed_x).\
247
+ view(1, w, w_kv, num_heads, self.qk_embed_dim).\
248
+ permute(0, 3, 1, 2, 4).\
249
+ repeat(n, 1, 1, 1, 1)
250
+
251
+ # (n, num_heads, h, h_kv, dim)
252
+ position_feat_y = self.appr_geom_fc_y(position_embed_y).\
253
+ view(1, h, h_kv, num_heads, self.qk_embed_dim).\
254
+ permute(0, 3, 1, 2, 4).\
255
+ repeat(n, 1, 1, 1, 1)
256
+
257
+ position_feat_x /= math.sqrt(2)
258
+ position_feat_y /= math.sqrt(2)
259
+
260
+ # accelerate for saliency only
261
+ if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
262
+ appr_bias = self.appr_bias.\
263
+ view(1, num_heads, 1, self.qk_embed_dim).\
264
+ repeat(n, 1, 1, 1)
265
+
266
+ energy = torch.matmul(appr_bias, proj_key).\
267
+ view(n, num_heads, 1, h_kv * w_kv)
268
+
269
+ h = 1
270
+ w = 1
271
+ else:
272
+ # (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
273
+ if not self.attention_type[0]:
274
+ energy = torch.zeros(
275
+ n,
276
+ num_heads,
277
+ h,
278
+ w,
279
+ h_kv,
280
+ w_kv,
281
+ dtype=x_input.dtype,
282
+ device=x_input.device)
283
+
284
+ # attention_type[0]: appr - appr
285
+ # attention_type[1]: appr - position
286
+ # attention_type[2]: bias - appr
287
+ # attention_type[3]: bias - position
288
+ if self.attention_type[0] or self.attention_type[2]:
289
+ if self.attention_type[0] and self.attention_type[2]:
290
+ appr_bias = self.appr_bias.\
291
+ view(1, num_heads, 1, self.qk_embed_dim)
292
+ energy = torch.matmul(proj_query + appr_bias, proj_key).\
293
+ view(n, num_heads, h, w, h_kv, w_kv)
294
+
295
+ elif self.attention_type[0]:
296
+ energy = torch.matmul(proj_query, proj_key).\
297
+ view(n, num_heads, h, w, h_kv, w_kv)
298
+
299
+ elif self.attention_type[2]:
300
+ appr_bias = self.appr_bias.\
301
+ view(1, num_heads, 1, self.qk_embed_dim).\
302
+ repeat(n, 1, 1, 1)
303
+
304
+ energy += torch.matmul(appr_bias, proj_key).\
305
+ view(n, num_heads, 1, 1, h_kv, w_kv)
306
+
307
+ if self.attention_type[1] or self.attention_type[3]:
308
+ if self.attention_type[1] and self.attention_type[3]:
309
+ geom_bias = self.geom_bias.\
310
+ view(1, num_heads, 1, self.qk_embed_dim)
311
+
312
+ proj_query_reshape = (proj_query + geom_bias).\
313
+ view(n, num_heads, h, w, self.qk_embed_dim)
314
+
315
+ energy_x = torch.matmul(
316
+ proj_query_reshape.permute(0, 1, 3, 2, 4),
317
+ position_feat_x.permute(0, 1, 2, 4, 3))
318
+ energy_x = energy_x.\
319
+ permute(0, 1, 3, 2, 4).unsqueeze(4)
320
+
321
+ energy_y = torch.matmul(
322
+ proj_query_reshape,
323
+ position_feat_y.permute(0, 1, 2, 4, 3))
324
+ energy_y = energy_y.unsqueeze(5)
325
+
326
+ energy += energy_x + energy_y
327
+
328
+ elif self.attention_type[1]:
329
+ proj_query_reshape = proj_query.\
330
+ view(n, num_heads, h, w, self.qk_embed_dim)
331
+ proj_query_reshape = proj_query_reshape.\
332
+ permute(0, 1, 3, 2, 4)
333
+ position_feat_x_reshape = position_feat_x.\
334
+ permute(0, 1, 2, 4, 3)
335
+ position_feat_y_reshape = position_feat_y.\
336
+ permute(0, 1, 2, 4, 3)
337
+
338
+ energy_x = torch.matmul(proj_query_reshape,
339
+ position_feat_x_reshape)
340
+ energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
341
+
342
+ energy_y = torch.matmul(proj_query_reshape,
343
+ position_feat_y_reshape)
344
+ energy_y = energy_y.unsqueeze(5)
345
+
346
+ energy += energy_x + energy_y
347
+
348
+ elif self.attention_type[3]:
349
+ geom_bias = self.geom_bias.\
350
+ view(1, num_heads, self.qk_embed_dim, 1).\
351
+ repeat(n, 1, 1, 1)
352
+
353
+ position_feat_x_reshape = position_feat_x.\
354
+ view(n, num_heads, w*w_kv, self.qk_embed_dim)
355
+
356
+ position_feat_y_reshape = position_feat_y.\
357
+ view(n, num_heads, h * h_kv, self.qk_embed_dim)
358
+
359
+ energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
360
+ energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
361
+
362
+ energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
363
+ energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
364
+
365
+ energy += energy_x + energy_y
366
+
367
+ energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
368
+
369
+ if self.spatial_range >= 0:
370
+ cur_local_constraint_map = \
371
+ self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
372
+ contiguous().\
373
+ view(1, 1, h*w, h_kv*w_kv)
374
+
375
+ energy = energy.masked_fill_(cur_local_constraint_map,
376
+ float('-inf'))
377
+
378
+ attention = F.softmax(energy, 3)
379
+
380
+ proj_value = self.value_conv(x_kv)
381
+ proj_value_reshape = proj_value.\
382
+ view((n, num_heads, self.v_dim, h_kv * w_kv)).\
383
+ permute(0, 1, 3, 2)
384
+
385
+ out = torch.matmul(attention, proj_value_reshape).\
386
+ permute(0, 1, 3, 2).\
387
+ contiguous().\
388
+ view(n, self.v_dim * self.num_heads, h, w)
389
+
390
+ out = self.proj_conv(out)
391
+
392
+ # output is downsampled, upsample back to input size
393
+ if self.q_downsample is not None:
394
+ out = F.interpolate(
395
+ out,
396
+ size=x_input.shape[2:],
397
+ mode='bilinear',
398
+ align_corners=False)
399
+
400
+ out = self.gamma * out + x_input
401
+ return out
402
+
403
+ def init_weights(self):
404
+ for m in self.modules():
405
+ if hasattr(m, 'kaiming_init') and m.kaiming_init:
406
+ kaiming_init(
407
+ m,
408
+ mode='fan_in',
409
+ nonlinearity='leaky_relu',
410
+ bias=0,
411
+ distribution='uniform',
412
+ a=1)
custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+
4
+ from .registry import ACTIVATION_LAYERS
5
+
6
+
7
+ @ACTIVATION_LAYERS.register_module()
8
+ class HSigmoid(nn.Module):
9
+ """Hard Sigmoid Module. Apply the hard sigmoid function:
10
+ Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
11
+ Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
12
+
13
+ Args:
14
+ bias (float): Bias of the input feature map. Default: 1.0.
15
+ divisor (float): Divisor of the input feature map. Default: 2.0.
16
+ min_value (float): Lower bound value. Default: 0.0.
17
+ max_value (float): Upper bound value. Default: 1.0.
18
+
19
+ Returns:
20
+ Tensor: The output tensor.
21
+ """
22
+
23
+ def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
24
+ super(HSigmoid, self).__init__()
25
+ self.bias = bias
26
+ self.divisor = divisor
27
+ assert self.divisor != 0
28
+ self.min_value = min_value
29
+ self.max_value = max_value
30
+
31
+ def forward(self, x):
32
+ x = (x + self.bias) / self.divisor
33
+
34
+ return x.clamp_(self.min_value, self.max_value)
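A tiny sketch of the default hard-sigmoid mapping (import path assumed from the cnn __init__ above):

    import torch
    from custom_mmpkg.custom_mmcv.cnn import HSigmoid

    hsig = HSigmoid()                     # min(max((x + 1) / 2, 0), 1)
    hsig(torch.tensor([-2.0, 0.0, 2.0]))  # -> tensor([0.0000, 0.5000, 1.0000])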
custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py ADDED
@@ -0,0 +1,29 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+
4
+ from .registry import ACTIVATION_LAYERS
5
+
6
+
7
+ @ACTIVATION_LAYERS.register_module()
8
+ class HSwish(nn.Module):
9
+ """Hard Swish Module.
10
+
11
+ This module applies the hard swish function:
12
+
13
+ .. math::
14
+ Hswish(x) = x * ReLU6(x + 3) / 6
15
+
16
+ Args:
17
+ inplace (bool): can optionally do the operation in-place.
18
+ Default: False.
19
+
20
+ Returns:
21
+ Tensor: The output tensor.
22
+ """
23
+
24
+ def __init__(self, inplace=False):
25
+ super(HSwish, self).__init__()
26
+ self.act = nn.ReLU6(inplace)
27
+
28
+ def forward(self, x):
29
+ return x * self.act(x + 3) / 6
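And the matching hard-swish sketch (same import assumption as above):

    import torch
    from custom_mmpkg.custom_mmcv.cnn import HSwish

    hswish = HSwish()
    hswish(torch.tensor([-4.0, 0.0, 3.0]))  # x * relu6(x + 3) / 6 -> tensor([-0., 0., 3.])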
custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py ADDED
@@ -0,0 +1,306 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from abc import ABCMeta
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from ..utils import constant_init, normal_init
8
+ from .conv_module import ConvModule
9
+ from .registry import PLUGIN_LAYERS
10
+
11
+
12
+ class _NonLocalNd(nn.Module, metaclass=ABCMeta):
13
+ """Basic Non-local module.
14
+
15
+ This module is proposed in
16
+ "Non-local Neural Networks"
17
+ Paper reference: https://arxiv.org/abs/1711.07971
18
+ Code reference: https://github.com/AlexHex7/Non-local_pytorch
19
+
20
+ Args:
21
+ in_channels (int): Channels of the input feature map.
22
+ reduction (int): Channel reduction ratio. Default: 2.
23
+ use_scale (bool): Whether to scale pairwise_weight by
24
+ `1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
25
+ Default: True.
26
+ conv_cfg (None | dict): The config dict for convolution layers.
27
+ If not specified, it will use `nn.Conv2d` for convolution layers.
28
+ Default: None.
29
+ norm_cfg (None | dict): The config dict for normalization layers.
30
+ Default: None. (This parameter is only applicable to conv_out.)
31
+ mode (str): Options are `gaussian`, `concatenation`,
32
+ `embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
33
+ """
34
+
35
+ def __init__(self,
36
+ in_channels,
37
+ reduction=2,
38
+ use_scale=True,
39
+ conv_cfg=None,
40
+ norm_cfg=None,
41
+ mode='embedded_gaussian',
42
+ **kwargs):
43
+ super(_NonLocalNd, self).__init__()
44
+ self.in_channels = in_channels
45
+ self.reduction = reduction
46
+ self.use_scale = use_scale
47
+ self.inter_channels = max(in_channels // reduction, 1)
48
+ self.mode = mode
49
+
50
+ if mode not in [
51
+ 'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
52
+ ]:
53
+ raise ValueError("Mode should be in 'gaussian', 'concatenation', "
54
+ f"'embedded_gaussian' or 'dot_product', but got "
55
+ f'{mode} instead.')
56
+
57
+ # g, theta, phi are defaulted as `nn.ConvNd`.
58
+ # Here we use ConvModule for potential usage.
59
+ self.g = ConvModule(
60
+ self.in_channels,
61
+ self.inter_channels,
62
+ kernel_size=1,
63
+ conv_cfg=conv_cfg,
64
+ act_cfg=None)
65
+ self.conv_out = ConvModule(
66
+ self.inter_channels,
67
+ self.in_channels,
68
+ kernel_size=1,
69
+ conv_cfg=conv_cfg,
70
+ norm_cfg=norm_cfg,
71
+ act_cfg=None)
72
+
73
+ if self.mode != 'gaussian':
74
+ self.theta = ConvModule(
75
+ self.in_channels,
76
+ self.inter_channels,
77
+ kernel_size=1,
78
+ conv_cfg=conv_cfg,
79
+ act_cfg=None)
80
+ self.phi = ConvModule(
81
+ self.in_channels,
82
+ self.inter_channels,
83
+ kernel_size=1,
84
+ conv_cfg=conv_cfg,
85
+ act_cfg=None)
86
+
87
+ if self.mode == 'concatenation':
88
+ self.concat_project = ConvModule(
89
+ self.inter_channels * 2,
90
+ 1,
91
+ kernel_size=1,
92
+ stride=1,
93
+ padding=0,
94
+ bias=False,
95
+ act_cfg=dict(type='ReLU'))
96
+
97
+ self.init_weights(**kwargs)
98
+
99
+ def init_weights(self, std=0.01, zeros_init=True):
100
+ if self.mode != 'gaussian':
101
+ for m in [self.g, self.theta, self.phi]:
102
+ normal_init(m.conv, std=std)
103
+ else:
104
+ normal_init(self.g.conv, std=std)
105
+ if zeros_init:
106
+ if self.conv_out.norm_cfg is None:
107
+ constant_init(self.conv_out.conv, 0)
108
+ else:
109
+ constant_init(self.conv_out.norm, 0)
110
+ else:
111
+ if self.conv_out.norm_cfg is None:
112
+ normal_init(self.conv_out.conv, std=std)
113
+ else:
114
+ normal_init(self.conv_out.norm, std=std)
115
+
116
+ def gaussian(self, theta_x, phi_x):
117
+ # NonLocal1d pairwise_weight: [N, H, H]
118
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
119
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
120
+ pairwise_weight = torch.matmul(theta_x, phi_x)
121
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
122
+ return pairwise_weight
123
+
124
+ def embedded_gaussian(self, theta_x, phi_x):
125
+ # NonLocal1d pairwise_weight: [N, H, H]
126
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
127
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
128
+ pairwise_weight = torch.matmul(theta_x, phi_x)
129
+ if self.use_scale:
130
+ # theta_x.shape[-1] is `self.inter_channels`
131
+ pairwise_weight /= theta_x.shape[-1]**0.5
132
+ pairwise_weight = pairwise_weight.softmax(dim=-1)
133
+ return pairwise_weight
134
+
135
+ def dot_product(self, theta_x, phi_x):
136
+ # NonLocal1d pairwise_weight: [N, H, H]
137
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
138
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
139
+ pairwise_weight = torch.matmul(theta_x, phi_x)
140
+ pairwise_weight /= pairwise_weight.shape[-1]
141
+ return pairwise_weight
142
+
143
+ def concatenation(self, theta_x, phi_x):
144
+ # NonLocal1d pairwise_weight: [N, H, H]
145
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
146
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
147
+ h = theta_x.size(2)
148
+ w = phi_x.size(3)
149
+ theta_x = theta_x.repeat(1, 1, 1, w)
150
+ phi_x = phi_x.repeat(1, 1, h, 1)
151
+
152
+ concat_feature = torch.cat([theta_x, phi_x], dim=1)
153
+ pairwise_weight = self.concat_project(concat_feature)
154
+ n, _, h, w = pairwise_weight.size()
155
+ pairwise_weight = pairwise_weight.view(n, h, w)
156
+ pairwise_weight /= pairwise_weight.shape[-1]
157
+
158
+ return pairwise_weight
159
+
160
+ def forward(self, x):
161
+ # Assume `reduction = 1`, then `inter_channels = C`
162
+ # or `inter_channels = C` when `mode="gaussian"`
163
+
164
+ # NonLocal1d x: [N, C, H]
165
+ # NonLocal2d x: [N, C, H, W]
166
+ # NonLocal3d x: [N, C, T, H, W]
167
+ n = x.size(0)
168
+
169
+ # NonLocal1d g_x: [N, H, C]
170
+ # NonLocal2d g_x: [N, HxW, C]
171
+ # NonLocal3d g_x: [N, TxHxW, C]
172
+ g_x = self.g(x).view(n, self.inter_channels, -1)
173
+ g_x = g_x.permute(0, 2, 1)
174
+
175
+ # NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
176
+ # NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
177
+ # NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
178
+ if self.mode == 'gaussian':
179
+ theta_x = x.view(n, self.in_channels, -1)
180
+ theta_x = theta_x.permute(0, 2, 1)
181
+ if self.sub_sample:
182
+ phi_x = self.phi(x).view(n, self.in_channels, -1)
183
+ else:
184
+ phi_x = x.view(n, self.in_channels, -1)
185
+ elif self.mode == 'concatenation':
186
+ theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
187
+ phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
188
+ else:
189
+ theta_x = self.theta(x).view(n, self.inter_channels, -1)
190
+ theta_x = theta_x.permute(0, 2, 1)
191
+ phi_x = self.phi(x).view(n, self.inter_channels, -1)
192
+
193
+ pairwise_func = getattr(self, self.mode)
194
+ # NonLocal1d pairwise_weight: [N, H, H]
195
+ # NonLocal2d pairwise_weight: [N, HxW, HxW]
196
+ # NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
197
+ pairwise_weight = pairwise_func(theta_x, phi_x)
198
+
199
+ # NonLocal1d y: [N, H, C]
200
+ # NonLocal2d y: [N, HxW, C]
201
+ # NonLocal3d y: [N, TxHxW, C]
202
+ y = torch.matmul(pairwise_weight, g_x)
203
+ # NonLocal1d y: [N, C, H]
204
+ # NonLocal2d y: [N, C, H, W]
205
+ # NonLocal3d y: [N, C, T, H, W]
206
+ y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
207
+ *x.size()[2:])
208
+
209
+ output = x + self.conv_out(y)
210
+
211
+ return output
212
+
213
+
214
+ class NonLocal1d(_NonLocalNd):
215
+ """1D Non-local module.
216
+
217
+ Args:
218
+ in_channels (int): Same as `NonLocalND`.
219
+ sub_sample (bool): Whether to apply max pooling after pairwise
220
+ function (Note that the `sub_sample` is applied on spatial only).
221
+ Default: False.
222
+ conv_cfg (None | dict): Same as `NonLocalND`.
223
+ Default: dict(type='Conv1d').
224
+ """
225
+
226
+ def __init__(self,
227
+ in_channels,
228
+ sub_sample=False,
229
+ conv_cfg=dict(type='Conv1d'),
230
+ **kwargs):
231
+ super(NonLocal1d, self).__init__(
232
+ in_channels, conv_cfg=conv_cfg, **kwargs)
233
+
234
+ self.sub_sample = sub_sample
235
+
236
+ if sub_sample:
237
+ max_pool_layer = nn.MaxPool1d(kernel_size=2)
238
+ self.g = nn.Sequential(self.g, max_pool_layer)
239
+ if self.mode != 'gaussian':
240
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
241
+ else:
242
+ self.phi = max_pool_layer
243
+
244
+
245
+ @PLUGIN_LAYERS.register_module()
246
+ class NonLocal2d(_NonLocalNd):
247
+ """2D Non-local module.
248
+
249
+ Args:
250
+ in_channels (int): Same as `NonLocalND`.
251
+ sub_sample (bool): Whether to apply max pooling after pairwise
252
+ function (Note that the `sub_sample` is applied on spatial only).
253
+ Default: False.
254
+ conv_cfg (None | dict): Same as `NonLocalND`.
255
+ Default: dict(type='Conv2d').
256
+ """
257
+
258
+ _abbr_ = 'nonlocal_block'
259
+
260
+ def __init__(self,
261
+ in_channels,
262
+ sub_sample=False,
263
+ conv_cfg=dict(type='Conv2d'),
264
+ **kwargs):
265
+ super(NonLocal2d, self).__init__(
266
+ in_channels, conv_cfg=conv_cfg, **kwargs)
267
+
268
+ self.sub_sample = sub_sample
269
+
270
+ if sub_sample:
271
+ max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
272
+ self.g = nn.Sequential(self.g, max_pool_layer)
273
+ if self.mode != 'gaussian':
274
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
275
+ else:
276
+ self.phi = max_pool_layer
277
+
278
+
279
+ class NonLocal3d(_NonLocalNd):
280
+ """3D Non-local module.
281
+
282
+ Args:
283
+ in_channels (int): Same as `NonLocalND`.
284
+ sub_sample (bool): Whether to apply max pooling after pairwise
285
+ function (Note that the `sub_sample` is applied on spatial only).
286
+ Default: False.
287
+ conv_cfg (None | dict): Same as `NonLocalND`.
288
+ Default: dict(type='Conv3d').
289
+ """
290
+
291
+ def __init__(self,
292
+ in_channels,
293
+ sub_sample=False,
294
+ conv_cfg=dict(type='Conv3d'),
295
+ **kwargs):
296
+ super(NonLocal3d, self).__init__(
297
+ in_channels, conv_cfg=conv_cfg, **kwargs)
298
+ self.sub_sample = sub_sample
299
+
300
+ if sub_sample:
301
+ max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
302
+ self.g = nn.Sequential(self.g, max_pool_layer)
303
+ if self.mode != 'gaussian':
304
+ self.phi = nn.Sequential(self.phi, max_pool_layer)
305
+ else:
306
+ self.phi = max_pool_layer
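Note: a hedged smoke test for the non-local blocks above (a sketch only; it assumes PyTorch and that this package's `ConvModule` dependencies import cleanly):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.non_local import NonLocal2d

block = NonLocal2d(in_channels=16, reduction=2, mode='embedded_gaussian')
x = torch.randn(2, 16, 8, 8)      # [N, C, H, W]
out = block(x)                    # residual connection, so the shape is preserved
assert out.shape == x.shape
# With the default zeros_init, conv_out starts at zero, so out initially equals x.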
custom_mmpkg/custom_mmcv/cnn/bricks/norm.py ADDED
@@ -0,0 +1,144 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import inspect
3
+
4
+ import torch.nn as nn
5
+
6
+ from custom_mmpkg.custom_mmcv.utils import is_tuple_of
7
+ from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
8
+ from .registry import NORM_LAYERS
9
+
10
+ NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
11
+ NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
12
+ NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
13
+ NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
14
+ NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
15
+ NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
16
+ NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
17
+ NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
18
+ NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
19
+ NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
20
+ NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
21
+
22
+
23
+ def infer_abbr(class_type):
24
+ """Infer abbreviation from the class name.
25
+
26
+ When we build a norm layer with `build_norm_layer()`, we want to preserve
27
+ the norm type in variable names, e.g, self.bn1, self.gn. This method will
28
+ infer the abbreviation to map class types to abbreviations.
29
+
30
+ Rule 1: If the class has the property "_abbr_", return the property.
31
+ Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
32
+ InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
33
+ "in" respectively.
34
+ Rule 3: If the class name contains "batch", "group", "layer" or "instance",
35
+ the abbreviation of this layer will be "bn", "gn", "ln" and "in"
36
+ respectively.
37
+ Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
38
+
39
+ Args:
40
+ class_type (type): The norm layer type.
41
+
42
+ Returns:
43
+ str: The inferred abbreviation.
44
+ """
45
+ if not inspect.isclass(class_type):
46
+ raise TypeError(
47
+ f'class_type must be a type, but got {type(class_type)}')
48
+ if hasattr(class_type, '_abbr_'):
49
+ return class_type._abbr_
50
+ if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
51
+ return 'in'
52
+ elif issubclass(class_type, _BatchNorm):
53
+ return 'bn'
54
+ elif issubclass(class_type, nn.GroupNorm):
55
+ return 'gn'
56
+ elif issubclass(class_type, nn.LayerNorm):
57
+ return 'ln'
58
+ else:
59
+ class_name = class_type.__name__.lower()
60
+ if 'batch' in class_name:
61
+ return 'bn'
62
+ elif 'group' in class_name:
63
+ return 'gn'
64
+ elif 'layer' in class_name:
65
+ return 'ln'
66
+ elif 'instance' in class_name:
67
+ return 'in'
68
+ else:
69
+ return 'norm_layer'
70
+
71
+
72
+ def build_norm_layer(cfg, num_features, postfix=''):
73
+ """Build normalization layer.
74
+
75
+ Args:
76
+ cfg (dict): The norm layer config, which should contain:
77
+
78
+ - type (str): Layer type.
79
+ - layer args: Args needed to instantiate a norm layer.
80
+ - requires_grad (bool, optional): Whether the layer parameters require gradient updates.
81
+ num_features (int): Number of input channels.
82
+ postfix (int | str): The postfix to be appended into norm abbreviation
83
+ to create named layer.
84
+
85
+ Returns:
86
+ (str, nn.Module): The first element is the layer name consisting of
87
+ abbreviation and postfix, e.g., bn1, gn. The second element is the
88
+ created norm layer.
89
+ """
90
+ if not isinstance(cfg, dict):
91
+ raise TypeError('cfg must be a dict')
92
+ if 'type' not in cfg:
93
+ raise KeyError('the cfg dict must contain the key "type"')
94
+ cfg_ = cfg.copy()
95
+
96
+ layer_type = cfg_.pop('type')
97
+ if layer_type not in NORM_LAYERS:
98
+ raise KeyError(f'Unrecognized norm type {layer_type}')
99
+
100
+ norm_layer = NORM_LAYERS.get(layer_type)
101
+ abbr = infer_abbr(norm_layer)
102
+
103
+ assert isinstance(postfix, (int, str))
104
+ name = abbr + str(postfix)
105
+
106
+ requires_grad = cfg_.pop('requires_grad', True)
107
+ cfg_.setdefault('eps', 1e-5)
108
+ if layer_type != 'GN':
109
+ layer = norm_layer(num_features, **cfg_)
110
+ if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
111
+ layer._specify_ddp_gpu_num(1)
112
+ else:
113
+ assert 'num_groups' in cfg_
114
+ layer = norm_layer(num_channels=num_features, **cfg_)
115
+
116
+ for param in layer.parameters():
117
+ param.requires_grad = requires_grad
118
+
119
+ return name, layer
120
+
121
+
122
+ def is_norm(layer, exclude=None):
123
+ """Check if a layer is a normalization layer.
124
+
125
+ Args:
126
+ layer (nn.Module): The layer to be checked.
127
+ exclude (type | tuple[type]): Types to be excluded.
128
+
129
+ Returns:
130
+ bool: Whether the layer is a norm layer.
131
+ """
132
+ if exclude is not None:
133
+ if not isinstance(exclude, tuple):
134
+ exclude = (exclude, )
135
+ if not is_tuple_of(exclude, type):
136
+ raise TypeError(
137
+ f'"exclude" must be either None or type or a tuple of types, '
138
+ f'but got {type(exclude)}: {exclude}')
139
+
140
+ if exclude and isinstance(layer, exclude):
141
+ return False
142
+
143
+ all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
144
+ return isinstance(layer, all_norm_bases)
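Note: a sketch of how the builder above is typically used (assumes the package imports resolve):

from custom_mmpkg.custom_mmcv.cnn.bricks.norm import build_norm_layer, is_norm

name, bn = build_norm_layer(dict(type='BN', requires_grad=True), num_features=64)
print(name, is_norm(bn))        # 'bn' True  (an nn.BatchNorm2d(64, eps=1e-05))

# GroupNorm takes `num_groups` from the cfg; `num_features` is passed as `num_channels`.
name, gn = build_norm_layer(dict(type='GN', num_groups=8), num_features=64, postfix=2)
print(name)                     # 'gn2'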
custom_mmpkg/custom_mmcv/cnn/bricks/padding.py ADDED
@@ -0,0 +1,36 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+
4
+ from .registry import PADDING_LAYERS
5
+
6
+ PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
7
+ PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
8
+ PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
9
+
10
+
11
+ def build_padding_layer(cfg, *args, **kwargs):
12
+ """Build padding layer.
13
+
14
+ Args:
15
+ cfg (None or dict): The padding layer config, which should contain:
16
+ - type (str): Layer type.
17
+ - layer args: Args needed to instantiate a padding layer.
18
+
19
+ Returns:
20
+ nn.Module: Created padding layer.
21
+ """
22
+ if not isinstance(cfg, dict):
23
+ raise TypeError('cfg must be a dict')
24
+ if 'type' not in cfg:
25
+ raise KeyError('the cfg dict must contain the key "type"')
26
+
27
+ cfg_ = cfg.copy()
28
+ padding_type = cfg_.pop('type')
29
+ if padding_type not in PADDING_LAYERS:
30
+ raise KeyError(f'Unrecognized padding type {padding_type}.')
31
+ else:
32
+ padding_layer = PADDING_LAYERS.get(padding_type)
33
+
34
+ layer = padding_layer(*args, **kwargs, **cfg_)
35
+
36
+ return layer
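Note: a small usage sketch for the padding builder above (assumes PyTorch is available):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 1)   # -> nn.ReflectionPad2d(1)
x = torch.randn(1, 3, 8, 8)
print(pad(x).shape)                                  # torch.Size([1, 3, 10, 10])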
custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py ADDED
@@ -0,0 +1,88 @@
1
+ import inspect
2
+ import platform
3
+
4
+ from .registry import PLUGIN_LAYERS
5
+
6
+ if platform.system() == 'Windows':
7
+ import regex as re
8
+ else:
9
+ import re
10
+
11
+
12
+ def infer_abbr(class_type):
13
+ """Infer abbreviation from the class name.
14
+
15
+ This method will infer the abbreviation to map class types to
16
+ abbreviations.
17
+
18
+ Rule 1: If the class has the property "abbr", return the property.
19
+ Rule 2: Otherwise, the abbreviation falls back to snake case of class
20
+ name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
21
+
22
+ Args:
23
+ class_type (type): The norm layer type.
24
+
25
+ Returns:
26
+ str: The inferred abbreviation.
27
+ """
28
+
29
+ def camel2snack(word):
30
+ """Convert a camel-case word into snake case.
31
+
32
+ Modified from `inflection lib
33
+ <https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.
34
+
35
+ Example::
36
+
37
+ >>> camel2snack("FancyBlock")
38
+ 'fancy_block'
39
+ """
40
+
41
+ word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
42
+ word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
43
+ word = word.replace('-', '_')
44
+ return word.lower()
45
+
46
+ if not inspect.isclass(class_type):
47
+ raise TypeError(
48
+ f'class_type must be a type, but got {type(class_type)}')
49
+ if hasattr(class_type, '_abbr_'):
50
+ return class_type._abbr_
51
+ else:
52
+ return camel2snack(class_type.__name__)
53
+
54
+
55
+ def build_plugin_layer(cfg, postfix='', **kwargs):
56
+ """Build plugin layer.
57
+
58
+ Args:
59
+ cfg (None or dict): cfg should contain:
60
+ type (str): identify plugin layer type.
61
+ layer args: args needed to instantiate a plugin layer.
62
+ postfix (int, str): appended into norm abbreviation to
63
+ create named layer. Default: ''.
64
+
65
+ Returns:
66
+ tuple[str, nn.Module]:
67
+ name (str): abbreviation + postfix
68
+ layer (nn.Module): created plugin layer
69
+ """
70
+ if not isinstance(cfg, dict):
71
+ raise TypeError('cfg must be a dict')
72
+ if 'type' not in cfg:
73
+ raise KeyError('the cfg dict must contain the key "type"')
74
+ cfg_ = cfg.copy()
75
+
76
+ layer_type = cfg_.pop('type')
77
+ if layer_type not in PLUGIN_LAYERS:
78
+ raise KeyError(f'Unrecognized plugin type {layer_type}')
79
+
80
+ plugin_layer = PLUGIN_LAYERS.get(layer_type)
81
+ abbr = infer_abbr(plugin_layer)
82
+
83
+ assert isinstance(postfix, (int, str))
84
+ name = abbr + str(postfix)
85
+
86
+ layer = plugin_layer(**kwargs, **cfg_)
87
+
88
+ return name, layer
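Note: a sketch of how the plugin builder pairs an inferred abbreviation with the built module, using the `NonLocal2d` block registered earlier in this upload (assumes the package imports resolve):

from custom_mmpkg.custom_mmcv.cnn.bricks.plugin import build_plugin_layer
from custom_mmpkg.custom_mmcv.cnn.bricks.non_local import NonLocal2d

name, layer = build_plugin_layer(dict(type='NonLocal2d', in_channels=16), postfix='_1')
print(name)                           # 'nonlocal_block_1' (NonLocal2d defines _abbr_)
assert isinstance(layer, NonLocal2d)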
custom_mmpkg/custom_mmcv/cnn/bricks/registry.py ADDED
@@ -0,0 +1,16 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from custom_mmpkg.custom_mmcv.utils import Registry
3
+
4
+ CONV_LAYERS = Registry('conv layer')
5
+ NORM_LAYERS = Registry('norm layer')
6
+ ACTIVATION_LAYERS = Registry('activation layer')
7
+ PADDING_LAYERS = Registry('padding layer')
8
+ UPSAMPLE_LAYERS = Registry('upsample layer')
9
+ PLUGIN_LAYERS = Registry('plugin layer')
10
+
11
+ DROPOUT_LAYERS = Registry('drop out layers')
12
+ POSITIONAL_ENCODING = Registry('position encoding')
13
+ ATTENTION = Registry('attention')
14
+ FEEDFORWARD_NETWORK = Registry('feed-forward Network')
15
+ TRANSFORMER_LAYER = Registry('transformerLayer')
16
+ TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
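Note: these registries are how the `build_*_layer` helpers in this directory look up classes by their `type` string. A sketch of registering a custom layer (the class below is hypothetical, purely for illustration):

import torch.nn as nn
from custom_mmpkg.custom_mmcv.cnn.bricks.registry import ACTIVATION_LAYERS

@ACTIVATION_LAYERS.register_module()
class ShiftedReLU(nn.Module):        # hypothetical example layer
    def __init__(self, shift=0.5):
        super().__init__()
        self.shift = shift

    def forward(self, x):
        return (x - self.shift).relu()

print('ShiftedReLU' in ACTIVATION_LAYERS)   # True; dict(type='ShiftedReLU') now builds it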
custom_mmpkg/custom_mmcv/cnn/bricks/scale.py ADDED
@@ -0,0 +1,21 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+
6
+ class Scale(nn.Module):
7
+ """A learnable scale parameter.
8
+
9
+ This layer scales the input by a learnable factor. It multiplies a
10
+ learnable scale parameter of shape (1,) with input of any shape.
11
+
12
+ Args:
13
+ scale (float): Initial value of scale factor. Default: 1.0
14
+ """
15
+
16
+ def __init__(self, scale=1.0):
17
+ super(Scale, self).__init__()
18
+ self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
19
+
20
+ def forward(self, x):
21
+ return x * self.scale
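Note: usage is a one-liner; the scale factor is a single trainable parameter (sketch, assuming PyTorch):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.scale import Scale

layer = Scale(scale=2.0)
print(layer(torch.ones(3)))                          # tensor([2., 2., 2.], grad_fn=...)
print(sum(p.numel() for p in layer.parameters()))    # 1 learnable scalar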
custom_mmpkg/custom_mmcv/cnn/bricks/swish.py ADDED
@@ -0,0 +1,25 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from .registry import ACTIVATION_LAYERS
6
+
7
+
8
+ @ACTIVATION_LAYERS.register_module()
9
+ class Swish(nn.Module):
10
+ """Swish Module.
11
+
12
+ This module applies the swish function:
13
+
14
+ .. math::
15
+ Swish(x) = x * Sigmoid(x)
16
+
17
+ Returns:
18
+ Tensor: The output tensor.
19
+ """
20
+
21
+ def __init__(self):
22
+ super(Swish, self).__init__()
23
+
24
+ def forward(self, x):
25
+ return x * torch.sigmoid(x)
custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py ADDED
@@ -0,0 +1,595 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import copy
3
+ import warnings
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from custom_mmpkg.custom_mmcv import ConfigDict, deprecated_api_warning
9
+ from custom_mmpkg.custom_mmcv.cnn import Linear, build_activation_layer, build_norm_layer
10
+ from custom_mmpkg.custom_mmcv.runner.base_module import BaseModule, ModuleList, Sequential
11
+ from custom_mmpkg.custom_mmcv.utils import build_from_cfg
12
+ from .drop import build_dropout
13
+ from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
14
+ TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)
15
+
16
+ # Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
17
+ try:
18
+ from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention # noqa F401
19
+ warnings.warn(
20
+ ImportWarning(
21
+ '``MultiScaleDeformableAttention`` has been moved to '
22
+ '``mmcv.ops.multi_scale_deform_attn``, please change original path ' # noqa E501
23
+ '``from custom_mmpkg.custom_mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` ' # noqa E501
24
+ 'to ``from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` ' # noqa E501
25
+ ))
26
+
27
+ except ImportError:
28
+ warnings.warn('Failed to import ``MultiScaleDeformableAttention`` from '
29
+ '``mmcv.ops.multi_scale_deform_attn``, '
30
+ 'You should install ``mmcv-full`` if you need this module. ')
31
+
32
+
33
+ def build_positional_encoding(cfg, default_args=None):
34
+ """Builder for Position Encoding."""
35
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
36
+
37
+
38
+ def build_attention(cfg, default_args=None):
39
+ """Builder for attention."""
40
+ return build_from_cfg(cfg, ATTENTION, default_args)
41
+
42
+
43
+ def build_feedforward_network(cfg, default_args=None):
44
+ """Builder for feed-forward network (FFN)."""
45
+ return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)
46
+
47
+
48
+ def build_transformer_layer(cfg, default_args=None):
49
+ """Builder for transformer layer."""
50
+ return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)
51
+
52
+
53
+ def build_transformer_layer_sequence(cfg, default_args=None):
54
+ """Builder for transformer encoder and transformer decoder."""
55
+ return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)
56
+
57
+
58
+ @ATTENTION.register_module()
59
+ class MultiheadAttention(BaseModule):
60
+ """A wrapper for ``torch.nn.MultiheadAttention``.
61
+
62
+ This module implements MultiheadAttention with identity connection,
63
+ and positional encoding is also passed as input.
64
+
65
+ Args:
66
+ embed_dims (int): The embedding dimension.
67
+ num_heads (int): Parallel attention heads.
68
+ attn_drop (float): A Dropout layer on attn_output_weights.
69
+ Default: 0.0.
70
+ proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
71
+ Default: 0.0.
72
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
73
+ when adding the shortcut.
74
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
75
+ Default: None.
76
+ batch_first (bool): When it is True, Key, Query and Value are shape of
77
+ (batch, n, embed_dim), otherwise (n, batch, embed_dim).
78
+ Default to False.
79
+ """
80
+
81
+ def __init__(self,
82
+ embed_dims,
83
+ num_heads,
84
+ attn_drop=0.,
85
+ proj_drop=0.,
86
+ dropout_layer=dict(type='Dropout', drop_prob=0.),
87
+ init_cfg=None,
88
+ batch_first=False,
89
+ **kwargs):
90
+ super(MultiheadAttention, self).__init__(init_cfg)
91
+ if 'dropout' in kwargs:
92
+ warnings.warn('The argument `dropout` in MultiheadAttention '
93
+ 'has been deprecated, now you can separately '
94
+ 'set `attn_drop`(float), proj_drop(float), '
95
+ 'and `dropout_layer`(dict) ')
96
+ attn_drop = kwargs['dropout']
97
+ dropout_layer['drop_prob'] = kwargs.pop('dropout')
98
+
99
+ self.embed_dims = embed_dims
100
+ self.num_heads = num_heads
101
+ self.batch_first = batch_first
102
+
103
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
104
+ **kwargs)
105
+
106
+ self.proj_drop = nn.Dropout(proj_drop)
107
+ self.dropout_layer = build_dropout(
108
+ dropout_layer) if dropout_layer else nn.Identity()
109
+
110
+ @deprecated_api_warning({'residual': 'identity'},
111
+ cls_name='MultiheadAttention')
112
+ def forward(self,
113
+ query,
114
+ key=None,
115
+ value=None,
116
+ identity=None,
117
+ query_pos=None,
118
+ key_pos=None,
119
+ attn_mask=None,
120
+ key_padding_mask=None,
121
+ **kwargs):
122
+ """Forward function for `MultiheadAttention`.
123
+
124
+ **kwargs allow passing a more general data flow when combining
125
+ with other operations in `transformerlayer`.
126
+
127
+ Args:
128
+ query (Tensor): The input query with shape [num_queries, bs,
129
+ embed_dims] if self.batch_first is False, else
130
+ [bs, num_queries, embed_dims].
131
+ key (Tensor): The key tensor with shape [num_keys, bs,
132
+ embed_dims] if self.batch_first is False, else
133
+ [bs, num_keys, embed_dims].
134
+ If None, the ``query`` will be used. Defaults to None.
135
+ value (Tensor): The value tensor with same shape as `key`.
136
+ Same in `nn.MultiheadAttention.forward`. Defaults to None.
137
+ If None, the `key` will be used.
138
+ identity (Tensor): This tensor, with the same shape as x,
139
+ will be used for the identity link.
140
+ If None, `x` will be used. Defaults to None.
141
+ query_pos (Tensor): The positional encoding for query, with
142
+ the same shape as `x`. If not None, it will
143
+ be added to `x` before forward function. Defaults to None.
144
+ key_pos (Tensor): The positional encoding for `key`, with the
145
+ same shape as `key`. Defaults to None. If not None, it will
146
+ be added to `key` before forward function. If None, and
147
+ `query_pos` has the same shape as `key`, then `query_pos`
148
+ will be used for `key_pos`. Defaults to None.
149
+ attn_mask (Tensor): ByteTensor mask with shape [num_queries,
150
+ num_keys]. Same in `nn.MultiheadAttention.forward`.
151
+ Defaults to None.
152
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
153
+ Defaults to None.
154
+
155
+ Returns:
156
+ Tensor: forwarded results with shape
157
+ [num_queries, bs, embed_dims]
158
+ if self.batch_first is False, else
159
+ [bs, num_queries, embed_dims].
160
+ """
161
+
162
+ if key is None:
163
+ key = query
164
+ if value is None:
165
+ value = key
166
+ if identity is None:
167
+ identity = query
168
+ if key_pos is None:
169
+ if query_pos is not None:
170
+ # use query_pos if key_pos is not available
171
+ if query_pos.shape == key.shape:
172
+ key_pos = query_pos
173
+ else:
174
+ warnings.warn(f'position encoding of key is '
175
+ f'missing in {self.__class__.__name__}.')
176
+ if query_pos is not None:
177
+ query = query + query_pos
178
+ if key_pos is not None:
179
+ key = key + key_pos
180
+
181
+ # Because the dataflow('key', 'query', 'value') of
182
+ # ``torch.nn.MultiheadAttention`` is (num_query, batch,
183
+ # embed_dims), We should adjust the shape of dataflow from
184
+ # batch_first (batch, num_query, embed_dims) to num_query_first
185
+ # (num_query ,batch, embed_dims), and recover ``attn_output``
186
+ # from num_query_first to batch_first.
187
+ if self.batch_first:
188
+ query = query.transpose(0, 1)
189
+ key = key.transpose(0, 1)
190
+ value = value.transpose(0, 1)
191
+
192
+ out = self.attn(
193
+ query=query,
194
+ key=key,
195
+ value=value,
196
+ attn_mask=attn_mask,
197
+ key_padding_mask=key_padding_mask)[0]
198
+
199
+ if self.batch_first:
200
+ out = out.transpose(0, 1)
201
+
202
+ return identity + self.dropout_layer(self.proj_drop(out))
203
+
204
+
205
+ @FEEDFORWARD_NETWORK.register_module()
206
+ class FFN(BaseModule):
207
+ """Implements feed-forward networks (FFNs) with identity connection.
208
+
209
+ Args:
210
+ embed_dims (int): The feature dimension. Same as
211
+ `MultiheadAttention`. Defaults: 256.
212
+ feedforward_channels (int): The hidden dimension of FFNs.
213
+ Defaults: 1024.
214
+ num_fcs (int, optional): The number of fully-connected layers in
215
+ FFNs. Default: 2.
216
+ act_cfg (dict, optional): The activation config for FFNs.
217
+ Default: dict(type='ReLU')
218
+ ffn_drop (float, optional): Probability of an element to be
219
+ zeroed in FFN. Default 0.0.
220
+ add_identity (bool, optional): Whether to add the
221
+ identity connection. Default: `True`.
222
+ dropout_layer (obj:`ConfigDict`): The dropout_layer used
223
+ when adding the shortcut.
224
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
225
+ Default: None.
226
+ """
227
+
228
+ @deprecated_api_warning(
229
+ {
230
+ 'dropout': 'ffn_drop',
231
+ 'add_residual': 'add_identity'
232
+ },
233
+ cls_name='FFN')
234
+ def __init__(self,
235
+ embed_dims=256,
236
+ feedforward_channels=1024,
237
+ num_fcs=2,
238
+ act_cfg=dict(type='ReLU', inplace=True),
239
+ ffn_drop=0.,
240
+ dropout_layer=None,
241
+ add_identity=True,
242
+ init_cfg=None,
243
+ **kwargs):
244
+ super(FFN, self).__init__(init_cfg)
245
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
246
+ f'than 2. got {num_fcs}.'
247
+ self.embed_dims = embed_dims
248
+ self.feedforward_channels = feedforward_channels
249
+ self.num_fcs = num_fcs
250
+ self.act_cfg = act_cfg
251
+ self.activate = build_activation_layer(act_cfg)
252
+
253
+ layers = []
254
+ in_channels = embed_dims
255
+ for _ in range(num_fcs - 1):
256
+ layers.append(
257
+ Sequential(
258
+ Linear(in_channels, feedforward_channels), self.activate,
259
+ nn.Dropout(ffn_drop)))
260
+ in_channels = feedforward_channels
261
+ layers.append(Linear(feedforward_channels, embed_dims))
262
+ layers.append(nn.Dropout(ffn_drop))
263
+ self.layers = Sequential(*layers)
264
+ self.dropout_layer = build_dropout(
265
+ dropout_layer) if dropout_layer else torch.nn.Identity()
266
+ self.add_identity = add_identity
267
+
268
+ @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
269
+ def forward(self, x, identity=None):
270
+ """Forward function for `FFN`.
271
+
272
+ The function would add x to the output tensor if residue is None.
273
+ """
274
+ out = self.layers(x)
275
+ if not self.add_identity:
276
+ return self.dropout_layer(out)
277
+ if identity is None:
278
+ identity = x
279
+ return identity + self.dropout_layer(out)
280
+
281
+
282
+ @TRANSFORMER_LAYER.register_module()
283
+ class BaseTransformerLayer(BaseModule):
284
+ """Base `TransformerLayer` for vision transformer.
285
+
286
+ It can be built from `mmcv.ConfigDict` and support more flexible
287
+ customization, for example, using any number of `FFN or LN ` and
288
+ use different kinds of `attention` by specifying a list of `ConfigDict`
289
+ named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
290
+ when you specifying `norm` as the first element of `operation_order`.
291
+ More details about the `prenorm`: `On Layer Normalization in the
292
+ Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .
293
+
294
+ Args:
295
+ attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
296
+ Configs for `self_attention` or `cross_attention` modules,
297
+ The order of the configs in the list should be consistent with
298
+ corresponding attentions in operation_order.
299
+ If it is a dict, all of the attention modules in operation_order
300
+ will be built with this config. Default: None.
301
+ ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None )):
302
+ Configs for FFN, The order of the configs in the list should be
303
+ consistent with corresponding ffn in operation_order.
304
+ If it is a dict, all of the attention modules in operation_order
305
+ will be built with this config.
306
+ operation_order (tuple[str]): The execution order of operation
307
+ in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
308
+ Support `prenorm` when you specifying first element as `norm`.
309
+ Default:None.
310
+ norm_cfg (dict): Config dict for normalization layer.
311
+ Default: dict(type='LN').
312
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
313
+ Default: None.
314
+ batch_first (bool): Key, Query and Value are shape
315
+ of (batch, n, embed_dim)
316
+ or (n, batch, embed_dim). Default to False.
317
+ """
318
+
319
+ def __init__(self,
320
+ attn_cfgs=None,
321
+ ffn_cfgs=dict(
322
+ type='FFN',
323
+ embed_dims=256,
324
+ feedforward_channels=1024,
325
+ num_fcs=2,
326
+ ffn_drop=0.,
327
+ act_cfg=dict(type='ReLU', inplace=True),
328
+ ),
329
+ operation_order=None,
330
+ norm_cfg=dict(type='LN'),
331
+ init_cfg=None,
332
+ batch_first=False,
333
+ **kwargs):
334
+
335
+ deprecated_args = dict(
336
+ feedforward_channels='feedforward_channels',
337
+ ffn_dropout='ffn_drop',
338
+ ffn_num_fcs='num_fcs')
339
+ for ori_name, new_name in deprecated_args.items():
340
+ if ori_name in kwargs:
341
+ warnings.warn(
342
+ f'The argument `{ori_name}` in BaseTransformerLayer '
343
+ f'has been deprecated, now you should set `{new_name}` '
344
+ f'and other FFN related arguments '
345
+ f'to a dict named `ffn_cfgs`. ')
346
+ ffn_cfgs[new_name] = kwargs[ori_name]
347
+
348
+ super(BaseTransformerLayer, self).__init__(init_cfg)
349
+
350
+ self.batch_first = batch_first
351
+
352
+ assert set(operation_order) & set(
353
+ ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
354
+ set(operation_order), f'The operation_order of' \
355
+ f' {self.__class__.__name__} should ' \
356
+ f'only contain operation types in ' \
357
+ f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"
358
+
359
+ num_attn = operation_order.count('self_attn') + operation_order.count(
360
+ 'cross_attn')
361
+ if isinstance(attn_cfgs, dict):
362
+ attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
363
+ else:
364
+ assert num_attn == len(attn_cfgs), f'The length ' \
365
+ f'of attn_cfg {num_attn} is ' \
366
+ f'not consistent with the number of attentions ' \
367
+ f'in operation_order {operation_order}.'
368
+
369
+ self.num_attn = num_attn
370
+ self.operation_order = operation_order
371
+ self.norm_cfg = norm_cfg
372
+ self.pre_norm = operation_order[0] == 'norm'
373
+ self.attentions = ModuleList()
374
+
375
+ index = 0
376
+ for operation_name in operation_order:
377
+ if operation_name in ['self_attn', 'cross_attn']:
378
+ if 'batch_first' in attn_cfgs[index]:
379
+ assert self.batch_first == attn_cfgs[index]['batch_first']
380
+ else:
381
+ attn_cfgs[index]['batch_first'] = self.batch_first
382
+ attention = build_attention(attn_cfgs[index])
383
+ # Some custom attentions used as `self_attn`
384
+ # or `cross_attn` can have different behavior.
385
+ attention.operation_name = operation_name
386
+ self.attentions.append(attention)
387
+ index += 1
388
+
389
+ self.embed_dims = self.attentions[0].embed_dims
390
+
391
+ self.ffns = ModuleList()
392
+ num_ffns = operation_order.count('ffn')
393
+ if isinstance(ffn_cfgs, dict):
394
+ ffn_cfgs = ConfigDict(ffn_cfgs)
395
+ if isinstance(ffn_cfgs, dict):
396
+ ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
397
+ assert len(ffn_cfgs) == num_ffns
398
+ for ffn_index in range(num_ffns):
399
+ if 'embed_dims' not in ffn_cfgs[ffn_index]:
400
+ ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
401
+ else:
402
+ assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
403
+ self.ffns.append(
404
+ build_feedforward_network(ffn_cfgs[ffn_index],
405
+ dict(type='FFN')))
406
+
407
+ self.norms = ModuleList()
408
+ num_norms = operation_order.count('norm')
409
+ for _ in range(num_norms):
410
+ self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
411
+
412
+ def forward(self,
413
+ query,
414
+ key=None,
415
+ value=None,
416
+ query_pos=None,
417
+ key_pos=None,
418
+ attn_masks=None,
419
+ query_key_padding_mask=None,
420
+ key_padding_mask=None,
421
+ **kwargs):
422
+ """Forward function for `TransformerDecoderLayer`.
423
+
424
+ **kwargs contains some specific arguments of attentions.
425
+
426
+ Args:
427
+ query (Tensor): The input query with shape
428
+ [num_queries, bs, embed_dims] if
429
+ self.batch_first is False, else
430
+ [bs, num_queries, embed_dims].
431
+ key (Tensor): The key tensor with shape [num_keys, bs,
432
+ embed_dims] if self.batch_first is False, else
433
+ [bs, num_keys, embed_dims].
434
+ value (Tensor): The value tensor with same shape as `key`.
435
+ query_pos (Tensor): The positional encoding for `query`.
436
+ Default: None.
437
+ key_pos (Tensor): The positional encoding for `key`.
438
+ Default: None.
439
+ attn_masks (List[Tensor] | None): 2D Tensor used in
440
+ calculation of corresponding attention. The length of
441
+ it should equal to the number of `attention` in
442
+ `operation_order`. Default: None.
443
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
444
+ shape [bs, num_queries]. Only used in `self_attn` layer.
445
+ Defaults to None.
446
+ key_padding_mask (Tensor): ByteTensor for `key`, with
447
+ shape [bs, num_keys]. Default: None.
448
+
449
+ Returns:
450
+ Tensor: forwarded results with shape [num_queries, bs, embed_dims].
451
+ """
452
+
453
+ norm_index = 0
454
+ attn_index = 0
455
+ ffn_index = 0
456
+ identity = query
457
+ if attn_masks is None:
458
+ attn_masks = [None for _ in range(self.num_attn)]
459
+ elif isinstance(attn_masks, torch.Tensor):
460
+ attn_masks = [
461
+ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
462
+ ]
463
+ warnings.warn(f'Use same attn_mask in all attentions in '
464
+ f'{self.__class__.__name__} ')
465
+ else:
466
+ assert len(attn_masks) == self.num_attn, f'The length of ' \
467
+ f'attn_masks {len(attn_masks)} must be equal ' \
468
+ f'to the number of attention in ' \
469
+ f'operation_order {self.num_attn}'
470
+
471
+ for layer in self.operation_order:
472
+ if layer == 'self_attn':
473
+ temp_key = temp_value = query
474
+ query = self.attentions[attn_index](
475
+ query,
476
+ temp_key,
477
+ temp_value,
478
+ identity if self.pre_norm else None,
479
+ query_pos=query_pos,
480
+ key_pos=query_pos,
481
+ attn_mask=attn_masks[attn_index],
482
+ key_padding_mask=query_key_padding_mask,
483
+ **kwargs)
484
+ attn_index += 1
485
+ identity = query
486
+
487
+ elif layer == 'norm':
488
+ query = self.norms[norm_index](query)
489
+ norm_index += 1
490
+
491
+ elif layer == 'cross_attn':
492
+ query = self.attentions[attn_index](
493
+ query,
494
+ key,
495
+ value,
496
+ identity if self.pre_norm else None,
497
+ query_pos=query_pos,
498
+ key_pos=key_pos,
499
+ attn_mask=attn_masks[attn_index],
500
+ key_padding_mask=key_padding_mask,
501
+ **kwargs)
502
+ attn_index += 1
503
+ identity = query
504
+
505
+ elif layer == 'ffn':
506
+ query = self.ffns[ffn_index](
507
+ query, identity if self.pre_norm else None)
508
+ ffn_index += 1
509
+
510
+ return query
511
+
512
+
513
+ @TRANSFORMER_LAYER_SEQUENCE.register_module()
514
+ class TransformerLayerSequence(BaseModule):
515
+ """Base class for TransformerEncoder and TransformerDecoder in vision
516
+ transformer.
517
+
518
+ As base-class of Encoder and Decoder in vision transformer.
519
+ Support customization such as specifying different kind
520
+ of `transformer_layer` in `transformer_coder`.
521
+
522
+ Args:
523
+ transformerlayers (list[obj:`mmcv.ConfigDict`] |
524
+ obj:`mmcv.ConfigDict`): Config of transformerlayer
525
+ in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
526
+ it would be repeated `num_layers` times to a
527
+ list[`mmcv.ConfigDict`]. Default: None.
528
+ num_layers (int): The number of `TransformerLayer`. Default: None.
529
+ init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
530
+ Default: None.
531
+ """
532
+
533
+ def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
534
+ super(TransformerLayerSequence, self).__init__(init_cfg)
535
+ if isinstance(transformerlayers, dict):
536
+ transformerlayers = [
537
+ copy.deepcopy(transformerlayers) for _ in range(num_layers)
538
+ ]
539
+ else:
540
+ assert isinstance(transformerlayers, list) and \
541
+ len(transformerlayers) == num_layers
542
+ self.num_layers = num_layers
543
+ self.layers = ModuleList()
544
+ for i in range(num_layers):
545
+ self.layers.append(build_transformer_layer(transformerlayers[i]))
546
+ self.embed_dims = self.layers[0].embed_dims
547
+ self.pre_norm = self.layers[0].pre_norm
548
+
549
+ def forward(self,
550
+ query,
551
+ key,
552
+ value,
553
+ query_pos=None,
554
+ key_pos=None,
555
+ attn_masks=None,
556
+ query_key_padding_mask=None,
557
+ key_padding_mask=None,
558
+ **kwargs):
559
+ """Forward function for `TransformerCoder`.
560
+
561
+ Args:
562
+ query (Tensor): Input query with shape
563
+ `(num_queries, bs, embed_dims)`.
564
+ key (Tensor): The key tensor with shape
565
+ `(num_keys, bs, embed_dims)`.
566
+ value (Tensor): The value tensor with shape
567
+ `(num_keys, bs, embed_dims)`.
568
+ query_pos (Tensor): The positional encoding for `query`.
569
+ Default: None.
570
+ key_pos (Tensor): The positional encoding for `key`.
571
+ Default: None.
572
+ attn_masks (List[Tensor], optional): Each element is 2D Tensor
573
+ which is used in calculation of corresponding attention in
574
+ operation_order. Default: None.
575
+ query_key_padding_mask (Tensor): ByteTensor for `query`, with
576
+ shape [bs, num_queries]. Only used in self-attention
577
+ Default: None.
578
+ key_padding_mask (Tensor): ByteTensor for `key`, with
579
+ shape [bs, num_keys]. Default: None.
580
+
581
+ Returns:
582
+ Tensor: results with shape [num_queries, bs, embed_dims].
583
+ """
584
+ for layer in self.layers:
585
+ query = layer(
586
+ query,
587
+ key,
588
+ value,
589
+ query_pos=query_pos,
590
+ key_pos=key_pos,
591
+ attn_masks=attn_masks,
592
+ query_key_padding_mask=query_key_padding_mask,
593
+ key_padding_mask=key_padding_mask,
594
+ **kwargs)
595
+ return query
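Note: a sketch of assembling one post-norm encoder layer from the config machinery above (assumes PyTorch and that the `MultiheadAttention`/`FFN` registrations in this file import cleanly; shapes follow the default non-batch-first convention):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.transformer import build_transformer_layer

layer = build_transformer_layer(dict(
    type='BaseTransformerLayer',
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
    operation_order=('self_attn', 'norm', 'ffn', 'norm')))   # post-norm ordering

query = torch.randn(100, 2, 256)     # [num_queries, bs, embed_dims]
out = layer(query)                   # key/value default to query (self-attention only)
assert out.shape == query.shape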
custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py ADDED
@@ -0,0 +1,84 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from ..utils import xavier_init
6
+ from .registry import UPSAMPLE_LAYERS
7
+
8
+ UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
9
+ UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)
10
+
11
+
12
+ @UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
13
+ class PixelShufflePack(nn.Module):
14
+ """Pixel Shuffle upsample layer.
15
+
16
+ This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
17
+ achieve a simple upsampling with pixel shuffle.
18
+
19
+ Args:
20
+ in_channels (int): Number of input channels.
21
+ out_channels (int): Number of output channels.
22
+ scale_factor (int): Upsample ratio.
23
+ upsample_kernel (int): Kernel size of the conv layer to expand the
24
+ channels.
25
+ """
26
+
27
+ def __init__(self, in_channels, out_channels, scale_factor,
28
+ upsample_kernel):
29
+ super(PixelShufflePack, self).__init__()
30
+ self.in_channels = in_channels
31
+ self.out_channels = out_channels
32
+ self.scale_factor = scale_factor
33
+ self.upsample_kernel = upsample_kernel
34
+ self.upsample_conv = nn.Conv2d(
35
+ self.in_channels,
36
+ self.out_channels * scale_factor * scale_factor,
37
+ self.upsample_kernel,
38
+ padding=(self.upsample_kernel - 1) // 2)
39
+ self.init_weights()
40
+
41
+ def init_weights(self):
42
+ xavier_init(self.upsample_conv, distribution='uniform')
43
+
44
+ def forward(self, x):
45
+ x = self.upsample_conv(x)
46
+ x = F.pixel_shuffle(x, self.scale_factor)
47
+ return x
48
+
49
+
50
+ def build_upsample_layer(cfg, *args, **kwargs):
51
+ """Build upsample layer.
52
+
53
+ Args:
54
+ cfg (dict): The upsample layer config, which should contain:
55
+
56
+ - type (str): Layer type.
57
+ - scale_factor (int): Upsample ratio, which is not applicable to
58
+ deconv.
59
+ - layer args: Args needed to instantiate an upsample layer.
60
+ args (argument list): Arguments passed to the ``__init__``
61
+ method of the corresponding upsample layer.
62
+ kwargs (keyword arguments): Keyword arguments passed to the
63
+ ``__init__`` method of the corresponding upsample layer.
64
+
65
+ Returns:
66
+ nn.Module: Created upsample layer.
67
+ """
68
+ if not isinstance(cfg, dict):
69
+ raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
70
+ if 'type' not in cfg:
71
+ raise KeyError(
72
+ f'the cfg dict must contain the key "type", but got {cfg}')
73
+ cfg_ = cfg.copy()
74
+
75
+ layer_type = cfg_.pop('type')
76
+ if layer_type not in UPSAMPLE_LAYERS:
77
+ raise KeyError(f'Unrecognized upsample type {layer_type}')
78
+ else:
79
+ upsample = UPSAMPLE_LAYERS.get(layer_type)
80
+
81
+ if upsample is nn.Upsample:
82
+ cfg_['mode'] = layer_type
83
+ layer = upsample(*args, **kwargs, **cfg_)
84
+ return layer
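Note: two typical uses of the upsample builder above, plain interpolation and the `PixelShufflePack` defined in this file (a sketch, assuming PyTorch):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.upsample import build_upsample_layer

x = torch.randn(1, 3, 8, 8)

up = build_upsample_layer(dict(type='bilinear', scale_factor=2, align_corners=False))
print(up(x).shape)                   # torch.Size([1, 3, 16, 16])

ps = build_upsample_layer(dict(type='pixel_shuffle', in_channels=3, out_channels=3,
                               scale_factor=2, upsample_kernel=3))
print(ps(x).shape)                   # torch.Size([1, 3, 16, 16])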
custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501
3
+
4
+ Wrap some nn modules to support empty tensor input. Currently, these wrappers
5
+ are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
6
+ heads are trained on only positive RoIs.
7
+ """
8
+ import math
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn.modules.utils import _pair, _triple
13
+
14
+ from .registry import CONV_LAYERS, UPSAMPLE_LAYERS
15
+
16
+ if torch.__version__ == 'parrots':
17
+ TORCH_VERSION = torch.__version__
18
+ else:
19
+ # torch.__version__ could be 1.3.1+cu92, we only need the first two
20
+ # for comparison
21
+ TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])
22
+
23
+
24
+ def obsolete_torch_version(torch_version, version_threshold):
25
+ return torch_version == 'parrots' or torch_version <= version_threshold
26
+
27
+
28
+ class NewEmptyTensorOp(torch.autograd.Function):
29
+
30
+ @staticmethod
31
+ def forward(ctx, x, new_shape):
32
+ ctx.shape = x.shape
33
+ return x.new_empty(new_shape)
34
+
35
+ @staticmethod
36
+ def backward(ctx, grad):
37
+ shape = ctx.shape
38
+ return NewEmptyTensorOp.apply(grad, shape), None
39
+
40
+
41
+ @CONV_LAYERS.register_module('Conv', force=True)
42
+ class Conv2d(nn.Conv2d):
43
+
44
+ def forward(self, x):
45
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
46
+ out_shape = [x.shape[0], self.out_channels]
47
+ for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
48
+ self.padding, self.stride, self.dilation):
49
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
50
+ out_shape.append(o)
51
+ empty = NewEmptyTensorOp.apply(x, out_shape)
52
+ if self.training:
53
+ # produce dummy gradient to avoid DDP warning.
54
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
55
+ return empty + dummy
56
+ else:
57
+ return empty
58
+
59
+ return super().forward(x)
60
+
61
+
62
+ @CONV_LAYERS.register_module('Conv3d', force=True)
63
+ class Conv3d(nn.Conv3d):
64
+
65
+ def forward(self, x):
66
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
67
+ out_shape = [x.shape[0], self.out_channels]
68
+ for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
69
+ self.padding, self.stride, self.dilation):
70
+ o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
71
+ out_shape.append(o)
72
+ empty = NewEmptyTensorOp.apply(x, out_shape)
73
+ if self.training:
74
+ # produce dummy gradient to avoid DDP warning.
75
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
76
+ return empty + dummy
77
+ else:
78
+ return empty
79
+
80
+ return super().forward(x)
81
+
82
+
83
+ @CONV_LAYERS.register_module()
84
+ @CONV_LAYERS.register_module('deconv')
85
+ @UPSAMPLE_LAYERS.register_module('deconv', force=True)
86
+ class ConvTranspose2d(nn.ConvTranspose2d):
87
+
88
+ def forward(self, x):
89
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
90
+ out_shape = [x.shape[0], self.out_channels]
91
+ for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
92
+ self.padding, self.stride,
93
+ self.dilation, self.output_padding):
94
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
95
+ empty = NewEmptyTensorOp.apply(x, out_shape)
96
+ if self.training:
97
+ # produce dummy gradient to avoid DDP warning.
98
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
99
+ return empty + dummy
100
+ else:
101
+ return empty
102
+
103
+ return super().forward(x)
104
+
105
+
106
+ @CONV_LAYERS.register_module()
107
+ @CONV_LAYERS.register_module('deconv3d')
108
+ @UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
109
+ class ConvTranspose3d(nn.ConvTranspose3d):
110
+
111
+ def forward(self, x):
112
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
113
+ out_shape = [x.shape[0], self.out_channels]
114
+ for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
115
+ self.padding, self.stride,
116
+ self.dilation, self.output_padding):
117
+ out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
118
+ empty = NewEmptyTensorOp.apply(x, out_shape)
119
+ if self.training:
120
+ # produce dummy gradient to avoid DDP warning.
121
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
122
+ return empty + dummy
123
+ else:
124
+ return empty
125
+
126
+ return super().forward(x)
127
+
128
+
129
+ class MaxPool2d(nn.MaxPool2d):
130
+
131
+ def forward(self, x):
132
+ # PyTorch 1.9 does not support empty tensor inference yet
133
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
134
+ out_shape = list(x.shape[:2])
135
+ for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
136
+ _pair(self.padding), _pair(self.stride),
137
+ _pair(self.dilation)):
138
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
139
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
140
+ out_shape.append(o)
141
+ empty = NewEmptyTensorOp.apply(x, out_shape)
142
+ return empty
143
+
144
+ return super().forward(x)
145
+
146
+
147
+ class MaxPool3d(nn.MaxPool3d):
148
+
149
+ def forward(self, x):
150
+ # PyTorch 1.9 does not support empty tensor inference yet
151
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
152
+ out_shape = list(x.shape[:2])
153
+ for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
154
+ _triple(self.padding),
155
+ _triple(self.stride),
156
+ _triple(self.dilation)):
157
+ o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
158
+ o = math.ceil(o) if self.ceil_mode else math.floor(o)
159
+ out_shape.append(o)
160
+ empty = NewEmptyTensorOp.apply(x, out_shape)
161
+ return empty
162
+
163
+ return super().forward(x)
164
+
165
+
166
+ class Linear(torch.nn.Linear):
167
+
168
+ def forward(self, x):
169
+ # empty tensor forward of Linear layer is supported in PyTorch 1.6
170
+ if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
171
+ out_shape = [x.shape[0], self.out_features]
172
+ empty = NewEmptyTensorOp.apply(x, out_shape)
173
+ if self.training:
174
+ # produce dummy gradient to avoid DDP warning.
175
+ dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
176
+ return empty + dummy
177
+ else:
178
+ return empty
179
+
180
+ return super().forward(x)
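Note: the point of these wrappers is graceful handling of zero-sized batches (e.g. no positive RoIs). A sketch (assumes PyTorch; on recent versions the short-circuit is skipped and native empty-tensor support is used):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.wrappers import Conv2d, Linear

conv = Conv2d(8, 4, kernel_size=3, padding=1)
print(conv(torch.randn(0, 8, 16, 16)).shape)   # torch.Size([0, 4, 16, 16])

fc = Linear(8, 2)
print(fc(torch.randn(0, 8)).shape)             # torch.Size([0, 2])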
custom_mmpkg/custom_mmcv/cnn/builder.py ADDED
@@ -0,0 +1,30 @@
1
+ # Copyright (c) OpenMMLab. All rights reserved.
2
+ from ..runner import Sequential
3
+ from ..utils import Registry, build_from_cfg
4
+
5
+
6
+ def build_model_from_cfg(cfg, registry, default_args=None):
7
+ """Build a PyTorch model from config dict(s). Different from
8
+ ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.
9
+
10
+ Args:
11
+ cfg (dict, list[dict]): The config of modules; it is either a config
12
+ dict or a list of config dicts. If cfg is a list,
13
+ the built modules will be wrapped with ``nn.Sequential``.
14
+ registry (:obj:`Registry`): A registry the module belongs to.
15
+ default_args (dict, optional): Default arguments to build the module.
16
+ Defaults to None.
17
+
18
+ Returns:
19
+ nn.Module: A built nn module.
20
+ """
21
+ if isinstance(cfg, list):
22
+ modules = [
23
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
24
+ ]
25
+ return Sequential(*modules)
26
+ else:
27
+ return build_from_cfg(cfg, registry, default_args)
28
+
29
+
30
+ MODELS = Registry('model', build_func=build_model_from_cfg)
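Note: a sketch of the intended use of `MODELS`: a single config dict builds one module, while a list of dicts is wrapped into a `Sequential` (the registered class below is hypothetical):

import torch.nn as nn
from custom_mmpkg.custom_mmcv.cnn.builder import MODELS, build_model_from_cfg

@MODELS.register_module()
class TinyHead(nn.Module):           # hypothetical example model
    def __init__(self, channels=8):
        super().__init__()
        self.fc = nn.Linear(channels, channels)

    def forward(self, x):
        return self.fc(x)

single = build_model_from_cfg(dict(type='TinyHead', channels=16), MODELS)
stacked = build_model_from_cfg([dict(type='TinyHead'), dict(type='TinyHead')], MODELS)
print(type(single).__name__, type(stacked).__name__)   # TinyHead Sequential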