Upload 564 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- custom_midas_repo/LICENSE +21 -0
- custom_midas_repo/README.md +259 -0
- custom_midas_repo/__init__.py +0 -0
- custom_midas_repo/hubconf.py +435 -0
- custom_midas_repo/midas/__init__.py +0 -0
- custom_midas_repo/midas/backbones/__init__.py +0 -0
- custom_midas_repo/midas/backbones/beit.py +196 -0
- custom_midas_repo/midas/backbones/levit.py +106 -0
- custom_midas_repo/midas/backbones/next_vit.py +39 -0
- custom_midas_repo/midas/backbones/swin.py +13 -0
- custom_midas_repo/midas/backbones/swin2.py +34 -0
- custom_midas_repo/midas/backbones/swin_common.py +52 -0
- custom_midas_repo/midas/backbones/utils.py +249 -0
- custom_midas_repo/midas/backbones/vit.py +221 -0
- custom_midas_repo/midas/base_model.py +16 -0
- custom_midas_repo/midas/blocks.py +439 -0
- custom_midas_repo/midas/dpt_depth.py +166 -0
- custom_midas_repo/midas/midas_net.py +76 -0
- custom_midas_repo/midas/midas_net_custom.py +128 -0
- custom_midas_repo/midas/model_loader.py +242 -0
- custom_midas_repo/midas/transforms.py +234 -0
- custom_mmpkg/__init__.py +1 -0
- custom_mmpkg/custom_mmcv/__init__.py +15 -0
- custom_mmpkg/custom_mmcv/arraymisc/__init__.py +4 -0
- custom_mmpkg/custom_mmcv/arraymisc/quantization.py +55 -0
- custom_mmpkg/custom_mmcv/cnn/__init__.py +41 -0
- custom_mmpkg/custom_mmcv/cnn/alexnet.py +61 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py +35 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/activation.py +92 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py +125 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/conv.py +44 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py +62 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py +206 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py +148 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py +96 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/drop.py +65 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py +412 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py +34 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py +29 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py +306 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/norm.py +144 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/padding.py +36 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py +88 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/registry.py +16 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/scale.py +21 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/swish.py +25 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py +595 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py +84 -0
- custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py +180 -0
- custom_mmpkg/custom_mmcv/cnn/builder.py +30 -0
custom_midas_repo/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
custom_midas_repo/README.md
ADDED
@@ -0,0 +1,259 @@
## Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer

This repository contains code to compute depth from a single image. It accompanies our [paper](https://arxiv.org/abs/1907.01341v3):

>Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer
René Ranftl, Katrin Lasinger, David Hafner, Konrad Schindler, Vladlen Koltun

and our [preprint](https://arxiv.org/abs/2103.13413):

> Vision Transformers for Dense Prediction
> René Ranftl, Alexey Bochkovskiy, Vladlen Koltun

MiDaS was trained on up to 12 datasets (ReDWeb, DIML, Movies, MegaDepth, WSVD, TartanAir, HRWSI, ApolloScape, BlendedMVS, IRS, KITTI, NYU Depth V2) with
multi-objective optimization.
The original model that was trained on 5 datasets (`MIX 5` in the paper) can be found [here](https://github.com/isl-org/MiDaS/releases/tag/v2).
The figure below shows an overview of the different MiDaS models; the bubble size scales with number of parameters.



### Setup

1) Pick one or more models and download the corresponding weights to the `weights` folder:

MiDaS 3.1
- For highest quality: [dpt_beit_large_512](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)
- For moderately less quality, but better speed-performance trade-off: [dpt_swin2_large_384](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)
- For embedded devices: [dpt_swin2_tiny_256](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt), [dpt_levit_224](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)
- For inference on Intel CPUs, OpenVINO may be used for the small legacy model: openvino_midas_v21_small [.xml](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.xml), [.bin](https://github.com/isl-org/MiDaS/releases/download/v3_1/openvino_midas_v21_small_256.bin)

MiDaS 3.0: Legacy transformer models [dpt_large_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) and [dpt_hybrid_384](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt)

MiDaS 2.1: Legacy convolutional models [midas_v21_384](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) and [midas_v21_small_256](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt)

1) Set up dependencies:

```shell
conda env create -f environment.yaml
conda activate midas-py310
```

#### optional

For the Next-ViT model, execute

```shell
git submodule add https://github.com/isl-org/Next-ViT midas/external/next_vit
```

For the OpenVINO model, install

```shell
pip install openvino
```

### Usage

1) Place one or more input images in the folder `input`.

2) Run the model with

```shell
python run.py --model_type <model_type> --input_path input --output_path output
```
where ```<model_type>``` is chosen from [dpt_beit_large_512](#model_type), [dpt_beit_large_384](#model_type),
[dpt_beit_base_384](#model_type), [dpt_swin2_large_384](#model_type), [dpt_swin2_base_384](#model_type),
[dpt_swin2_tiny_256](#model_type), [dpt_swin_large_384](#model_type), [dpt_next_vit_large_384](#model_type),
[dpt_levit_224](#model_type), [dpt_large_384](#model_type), [dpt_hybrid_384](#model_type),
[midas_v21_384](#model_type), [midas_v21_small_256](#model_type), [openvino_midas_v21_small_256](#model_type).

3) The resulting depth maps are written to the `output` folder.

#### optional

1) By default, the inference resizes the height of input images to the size of a model to fit into the encoder. This
   size is given by the numbers in the model names of the [accuracy table](#accuracy). Some models do not only support a single
   inference height but a range of different heights. Feel free to explore different heights by appending the extra
   command line argument `--height`. Unsupported height values will throw an error. Note that using this argument may
   decrease the model accuracy.
2) By default, the inference keeps the aspect ratio of input images when feeding them into the encoder if this is
   supported by a model (all models except for Swin, Swin2, LeViT). In order to resize to a square resolution,
   disregarding the aspect ratio while preserving the height, use the command line argument `--square`.

#### via Camera

If you want the input images to be grabbed from the camera and shown in a window, leave the input and output paths
away and choose a model type as shown above:

```shell
python run.py --model_type <model_type> --side
```

The argument `--side` is optional and causes both the input RGB image and the output depth map to be shown
side-by-side for comparison.

#### via Docker

1) Make sure you have installed Docker and the
   [NVIDIA Docker runtime](https://github.com/NVIDIA/nvidia-docker/wiki/Installation-\(Native-GPU-Support\)).

2) Build the Docker image:

```shell
docker build -t midas .
```

3) Run inference:

```shell
docker run --rm --gpus all -v $PWD/input:/opt/MiDaS/input -v $PWD/output:/opt/MiDaS/output -v $PWD/weights:/opt/MiDaS/weights midas
```

This command passes through all of your NVIDIA GPUs to the container, mounts the
`input` and `output` directories and then runs the inference.

#### via PyTorch Hub

The pretrained model is also available on [PyTorch Hub](https://pytorch.org/hub/intelisl_midas_v2/)

#### via TensorFlow or ONNX

See [README](https://github.com/isl-org/MiDaS/tree/master/tf) in the `tf` subdirectory.

Currently only supports MiDaS v2.1.


#### via Mobile (iOS / Android)

See [README](https://github.com/isl-org/MiDaS/tree/master/mobile) in the `mobile` subdirectory.

#### via ROS1 (Robot Operating System)

See [README](https://github.com/isl-org/MiDaS/tree/master/ros) in the `ros` subdirectory.

Currently only supports MiDaS v2.1. DPT-based models to be added.


### Accuracy

We provide a **zero-shot error** $\epsilon_d$ which is evaluated for 6 different datasets
(see [paper](https://arxiv.org/abs/1907.01341v3)). **Lower error values are better**.
$\color{green}{\textsf{Overall model quality is represented by the improvement}}$ ([Imp.](#improvement)) with respect to
MiDaS 3.0 DPT<sub>L-384</sub>. The models are grouped by the height used for inference, whereas the square training resolution is given by
the numbers in the model names. The table also shows the **number of parameters** (in millions) and the
**frames per second** for inference at the training resolution (for GPU RTX 3090):

| MiDaS Model | DIW </br><sup>WHDR</sup> | Eth3d </br><sup>AbsRel</sup> | Sintel </br><sup>AbsRel</sup> | TUM </br><sup>δ1</sup> | KITTI </br><sup>δ1</sup> | NYUv2 </br><sup>δ1</sup> | $\color{green}{\textsf{Imp.}}$ </br><sup>%</sup> | Par.</br><sup>M</sup> | FPS</br><sup> </sup> |
|-------------|-------------------------:|-----------------------------:|------------------------------:|-----------------------:|-------------------------:|-------------------------:|-------------------------------------------------:|----------------------:|---------------------:|
| **Inference height 512** | | | | | | | | | |
| [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1137 | 0.0659 | 0.2366 | **6.13** | 11.56* | **1.86*** | $\color{green}{\textsf{19}}$ | **345** | **5.7** |
| [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt)$\tiny{\square}$ | **0.1121** | **0.0614** | **0.2090** | 6.46 | **5.00*** | 1.90* | $\color{green}{\textsf{34}}$ | **345** | **5.7** |
| | | | | | | | | | |
| **Inference height 384** | | | | | | | | | |
| [v3.1 BEiT<sub>L-512</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt) | 0.1245 | 0.0681 | **0.2176** | **6.13** | 6.28* | **2.16*** | $\color{green}{\textsf{28}}$ | 345 | 12 |
| [v3.1 Swin2<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt)$\tiny{\square}$ | 0.1106 | 0.0732 | 0.2442 | 8.87 | **5.84*** | 2.92* | $\color{green}{\textsf{22}}$ | 213 | 41 |
| [v3.1 Swin2<sub>B-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt)$\tiny{\square}$ | 0.1095 | 0.0790 | 0.2404 | 8.93 | 5.97* | 3.28* | $\color{green}{\textsf{22}}$ | 102 | 39 |
| [v3.1 Swin<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt)$\tiny{\square}$ | 0.1126 | 0.0853 | 0.2428 | 8.74 | 6.60* | 3.34* | $\color{green}{\textsf{17}}$ | 213 | 49 |
| [v3.1 BEiT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt) | 0.1239 | **0.0667** | 0.2545 | 7.17 | 9.84* | 2.21* | $\color{green}{\textsf{17}}$ | 344 | 13 |
| [v3.1 Next-ViT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt) | **0.1031** | 0.0954 | 0.2295 | 9.21 | 6.89* | 3.47* | $\color{green}{\textsf{16}}$ | **72** | 30 |
| [v3.1 BEiT<sub>B-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt) | 0.1159 | 0.0967 | 0.2901 | 9.88 | 26.60* | 3.91* | $\color{green}{\textsf{-31}}$ | 112 | 31 |
| [v3.0 DPT<sub>L-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt) | 0.1082 | 0.0888 | 0.2697 | 9.97 | 8.46 | 8.32 | $\color{green}{\textsf{0}}$ | 344 | **61** |
| [v3.0 DPT<sub>H-384</sub>](https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt) | 0.1106 | 0.0934 | 0.2741 | 10.89 | 11.56 | 8.69 | $\color{green}{\textsf{-10}}$ | 123 | 50 |
| [v2.1 Large<sub>384</sub>](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt) | 0.1295 | 0.1155 | 0.3285 | 12.51 | 16.08 | 8.71 | $\color{green}{\textsf{-32}}$ | 105 | 47 |
| | | | | | | | | | |
| **Inference height 256** | | | | | | | | | |
| [v3.1 Swin2<sub>T-256</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt)$\tiny{\square}$ | **0.1211** | **0.1106** | **0.2868** | **13.43** | **10.13*** | **5.55*** | $\color{green}{\textsf{-11}}$ | 42 | 64 |
| [v2.1 Small<sub>256</sub>](https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt) | 0.1344 | 0.1344 | 0.3370 | 14.53 | 29.27 | 13.43 | $\color{green}{\textsf{-76}}$ | **21** | **90** |
| | | | | | | | | | |
| **Inference height 224** | | | | | | | | | |
| [v3.1 LeViT<sub>224</sub>](https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt)$\tiny{\square}$ | **0.1314** | **0.1206** | **0.3148** | **18.21** | **15.27*** | **8.64*** | $\color{green}{\textsf{-40}}$ | **51** | **73** |

* No zero-shot error, because models are also trained on KITTI and NYU Depth V2\
$\square$ Validation performed at **square resolution**, either because the transformer encoder backbone of a model
does not support non-square resolutions (Swin, Swin2, LeViT) or for comparison with these models. All other
validations keep the aspect ratio. A difference in resolution limits the comparability of the zero-shot error and the
improvement, because these quantities are averages over the pixels of an image and do not take into account the
advantage of more details due to a higher resolution.\
Best values per column and same validation height in bold

#### Improvement

The improvement in the above table is defined as the relative zero-shot error with respect to MiDaS v3.0
DPT<sub>L-384</sub> and averaging over the datasets. So, if $\epsilon_d$ is the zero-shot error for dataset $d$, then
the $\color{green}{\textsf{improvement}}$ is given by $100(1-(1/6)\sum_d\epsilon_d/\epsilon_{d,\rm{DPT_{L-384}}})$%.

Note that the improvements of 10% for MiDaS v2.0 → v2.1 and 21% for MiDaS v2.1 → v3.0 are not visible from the
improvement column (Imp.) in the table but would require an evaluation with respect to MiDaS v2.1 Large<sub>384</sub>
and v2.0 Large<sub>384</sub> respectively instead of v3.0 DPT<sub>L-384</sub>.

### Depth map comparison

Zoom in for better visibility


### Speed on Camera Feed

Test configuration
- Windows 10
- 11th Gen Intel Core i7-1185G7 3.00GHz
- 16GB RAM
- Camera resolution 640x480
- openvino_midas_v21_small_256

Speed: 22 FPS

### Changelog

* [Dec 2022] Released MiDaS v3.1:
    - New models based on 5 different types of transformers ([BEiT](https://arxiv.org/pdf/2106.08254.pdf), [Swin2](https://arxiv.org/pdf/2111.09883.pdf), [Swin](https://arxiv.org/pdf/2103.14030.pdf), [Next-ViT](https://arxiv.org/pdf/2207.05501.pdf), [LeViT](https://arxiv.org/pdf/2104.01136.pdf))
    - Training datasets extended from 10 to 12, including also KITTI and NYU Depth V2 using [BTS](https://github.com/cleinc/bts) split
    - Best model, BEiT<sub>Large 512</sub>, with resolution 512x512, is on average about [28% more accurate](#Accuracy) than MiDaS v3.0
    - Integrated live depth estimation from camera feed
* [Sep 2021] Integrated to [Huggingface Spaces](https://huggingface.co/spaces) with [Gradio](https://github.com/gradio-app/gradio). See [Gradio Web Demo](https://huggingface.co/spaces/akhaliq/DPT-Large).
* [Apr 2021] Released MiDaS v3.0:
    - New models based on [Dense Prediction Transformers](https://arxiv.org/abs/2103.13413) are on average [21% more accurate](#Accuracy) than MiDaS v2.1
    - Additional models can be found [here](https://github.com/isl-org/DPT)
* [Nov 2020] Released MiDaS v2.1:
    - New model that was trained on 10 datasets and is on average about [10% more accurate](#Accuracy) than [MiDaS v2.0](https://github.com/isl-org/MiDaS/releases/tag/v2)
    - New light-weight model that achieves [real-time performance](https://github.com/isl-org/MiDaS/tree/master/mobile) on mobile platforms.
    - Sample applications for [iOS](https://github.com/isl-org/MiDaS/tree/master/mobile/ios) and [Android](https://github.com/isl-org/MiDaS/tree/master/mobile/android)
    - [ROS package](https://github.com/isl-org/MiDaS/tree/master/ros) for easy deployment on robots
* [Jul 2020] Added TensorFlow and ONNX code. Added [online demo](http://35.202.76.57/).
* [Dec 2019] Released new version of MiDaS - the new model is significantly more accurate and robust
* [Jul 2019] Initial release of MiDaS ([Link](https://github.com/isl-org/MiDaS/releases/tag/v1))

### Citation

Please cite our paper if you use this code or any of the models:
```
@ARTICLE {Ranftl2022,
    author  = "Ren\'{e} Ranftl and Katrin Lasinger and David Hafner and Konrad Schindler and Vladlen Koltun",
    title   = "Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-Shot Cross-Dataset Transfer",
    journal = "IEEE Transactions on Pattern Analysis and Machine Intelligence",
    year    = "2022",
    volume  = "44",
    number  = "3"
}
```

If you use a DPT-based model, please also cite:

```
@article{Ranftl2021,
    author  = {Ren\'{e} Ranftl and Alexey Bochkovskiy and Vladlen Koltun},
    title   = {Vision Transformers for Dense Prediction},
    journal = {ICCV},
    year    = {2021},
}
```

### Acknowledgements

Our work builds on and uses code from [timm](https://github.com/rwightman/pytorch-image-models) and [Next-ViT](https://github.com/bytedance/Next-ViT).
We'd like to thank the authors for making these libraries available.

### License

MIT License
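Note: the Usage and "via PyTorch Hub" sections of the README above describe how these models are normally run. The sketch below illustrates the documented torch.hub workflow end to end; it is an illustration only, it pulls the upstream `intel-isl/MiDaS` hub entry rather than the vendored `custom_midas_repo` package added in this commit, and the image path `input/dog.jpg` is a placeholder.

```python
# Minimal sketch of running MiDaS through PyTorch Hub (assumes cv2 and torch are installed).
import cv2
import torch

model_type = "DPT_Large"  # any entry point from hubconf.py, e.g. "DPT_Hybrid" or "MiDaS_small"
midas = torch.hub.load("intel-isl/MiDaS", model_type)
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.dpt_transform if "DPT" in model_type else midas_transforms.small_transform

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas.to(device).eval()

img = cv2.cvtColor(cv2.imread("input/dog.jpg"), cv2.COLOR_BGR2RGB)  # placeholder path

with torch.no_grad():
    prediction = midas(transform(img).to(device))
    # Resize the relative inverse-depth map back to the input resolution.
    prediction = torch.nn.functional.interpolate(
        prediction.unsqueeze(1), size=img.shape[:2], mode="bicubic", align_corners=False
    ).squeeze()

depth = prediction.cpu().numpy()  # larger values correspond to closer surfaces
```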
custom_midas_repo/__init__.py
ADDED
File without changes
custom_midas_repo/hubconf.py
ADDED
@@ -0,0 +1,435 @@
dependencies = ["torch"]

import torch

from custom_midas_repo.midas.dpt_depth import DPTDepthModel
from custom_midas_repo.midas.midas_net import MidasNet
from custom_midas_repo.midas.midas_net_custom import MidasNet_small

def DPT_BEiT_L_512(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_BEiT_L_512 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="beitl16_512",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_512.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_BEiT_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_BEiT_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="beitl16_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_BEiT_B_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_BEiT_B_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="beitb16_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_base_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_SwinV2_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_SwinV2_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swin2l24_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_SwinV2_B_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_SwinV2_B_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swin2b24_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_base_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_SwinV2_T_256(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_SwinV2_T_256 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swin2t16_256",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin2_tiny_256.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_Swin_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_Swin_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="swinl12_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_swin_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_Next_ViT_L_384(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_Next_ViT_L_384 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="next_vit_large_6m",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_next_vit_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_LeViT_224(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT_LeViT_224 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="levit_384",
        non_negative=True,
        head_features_1=64,
        head_features_2=8,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_levit_224.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_Large(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT-Large model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="vitl16_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_large_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def DPT_Hybrid(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS DPT-Hybrid model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = DPTDepthModel(
        path=None,
        backbone="vitb_rn50_384",
        non_negative=True,
    )

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def MiDaS(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS v2.1 model for monocular depth estimation
    pretrained (bool): load pretrained weights into model
    """

    model = MidasNet()

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_384.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model

def MiDaS_small(pretrained=True, **kwargs):
    """ # This docstring shows up in hub.help()
    MiDaS v2.1 small model for monocular depth estimation on resource-constrained devices
    pretrained (bool): load pretrained weights into model
    """

    model = MidasNet_small(None, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True})

    if pretrained:
        checkpoint = (
            "https://github.com/isl-org/MiDaS/releases/download/v2_1/midas_v21_small_256.pt"
        )
        state_dict = torch.hub.load_state_dict_from_url(
            checkpoint, map_location=torch.device('cpu'), progress=True, check_hash=True
        )
        model.load_state_dict(state_dict)

    return model


def transforms():
    import cv2
    from torchvision.transforms import Compose
    from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet
    from custom_midas_repo.midas import transforms

    transforms.default_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.small_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                256,
                256,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="upper_bound",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.dpt_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.beit512_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                512,
                512,
                resize_target=None,
                keep_aspect_ratio=True,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.swin384_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                384,
                384,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.swin256_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                256,
                256,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    transforms.levit_transform = Compose(
        [
            lambda img: {"image": img / 255.0},
            Resize(
                224,
                224,
                resize_target=None,
                keep_aspect_ratio=False,
                ensure_multiple_of=32,
                resize_method="minimal",
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            PrepareForNet(),
            lambda sample: torch.from_numpy(sample["image"]).unsqueeze(0),
        ]
    )

    return transforms
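Note: hubconf.py above exposes one constructor per model plus a shared `transforms()` helper. A sketch of calling the vendored entry points directly, rather than through torch.hub; it assumes the repository root is on `PYTHONPATH`, the vendored `custom_timm` and its dependencies import cleanly, and the checkpoint URL is reachable. The random image is a stand-in for real input.

```python
# Sketch only: direct use of the vendored hubconf entry points.
import numpy as np
import torch

from custom_midas_repo import hubconf

model = hubconf.DPT_BEiT_L_512(pretrained=True)  # downloads dpt_beit_large_512.pt on first use
model.eval()

tfms = hubconf.transforms()  # module exposing default_/small_/dpt_/beit512_/swin.../levit_transform

# Stand-in for a real RGB image loaded with OpenCV (H x W x 3, values 0-255).
img = np.random.randint(0, 255, (480, 640, 3), dtype=np.uint8)
batch = tfms.beit512_transform(img)  # 1 x 3 x H' x W' float tensor

with torch.no_grad():
    inverse_depth = model(batch)  # relative inverse depth, shape 1 x H' x W'
```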
custom_midas_repo/midas/__init__.py
ADDED
File without changes
custom_midas_repo/midas/backbones/__init__.py
ADDED
File without changes
custom_midas_repo/midas/backbones/beit.py
ADDED
@@ -0,0 +1,196 @@
import custom_timm as timm
import torch
import types

import numpy as np
import torch.nn.functional as F

from .utils import forward_adapted_unflatten, make_backbone_default
from custom_timm.models.beit import gen_relative_position_index
from torch.utils.checkpoint import checkpoint
from typing import Optional


def forward_beit(pretrained, x):
    return forward_adapted_unflatten(pretrained, x, "forward_features")


def patch_embed_forward(self, x):
    """
    Modification of timm.models.layers.patch_embed.py: PatchEmbed.forward to support arbitrary window sizes.
    """
    x = self.proj(x)
    if self.flatten:
        x = x.flatten(2).transpose(1, 2)
    x = self.norm(x)
    return x


def _get_rel_pos_bias(self, window_size):
    """
    Modification of timm.models.beit.py: Attention._get_rel_pos_bias to support arbitrary window sizes.
    """
    old_height = 2 * self.window_size[0] - 1
    old_width = 2 * self.window_size[1] - 1

    new_height = 2 * window_size[0] - 1
    new_width = 2 * window_size[1] - 1

    old_relative_position_bias_table = self.relative_position_bias_table

    old_num_relative_distance = self.num_relative_distance
    new_num_relative_distance = new_height * new_width + 3

    old_sub_table = old_relative_position_bias_table[:old_num_relative_distance - 3]

    old_sub_table = old_sub_table.reshape(1, old_width, old_height, -1).permute(0, 3, 1, 2)
    new_sub_table = F.interpolate(old_sub_table, size=(int(new_height), int(new_width)), mode="bilinear")
    new_sub_table = new_sub_table.permute(0, 2, 3, 1).reshape(new_num_relative_distance - 3, -1)

    new_relative_position_bias_table = torch.cat(
        [new_sub_table, old_relative_position_bias_table[old_num_relative_distance - 3:]])

    key = str(window_size[1]) + "," + str(window_size[0])
    if key not in self.relative_position_indices.keys():
        self.relative_position_indices[key] = gen_relative_position_index(window_size)

    relative_position_bias = new_relative_position_bias_table[
        self.relative_position_indices[key].view(-1)].view(
        window_size[0] * window_size[1] + 1,
        window_size[0] * window_size[1] + 1, -1)  # Wh*Ww,Wh*Ww,nH
    relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
    return relative_position_bias.unsqueeze(0)


def attention_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
    """
    Modification of timm.models.beit.py: Attention.forward to support arbitrary window sizes.
    """
    B, N, C = x.shape

    qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
    qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
    qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
    q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

    q = q * self.scale
    attn = (q @ k.transpose(-2, -1))

    if self.relative_position_bias_table is not None:
        window_size = tuple(np.array(resolution) // 16)
        attn = attn + self._get_rel_pos_bias(window_size)
    if shared_rel_pos_bias is not None:
        attn = attn + shared_rel_pos_bias

    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
    x = self.proj(x)
    x = self.proj_drop(x)
    return x


def block_forward(self, x, resolution, shared_rel_pos_bias: Optional[torch.Tensor] = None):
    """
    Modification of timm.models.beit.py: Block.forward to support arbitrary window sizes.
    """
    if self.gamma_1 is None:
        x = x + self.drop_path(self.attn(self.norm1(x), resolution, shared_rel_pos_bias=shared_rel_pos_bias))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
    else:
        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x), resolution,
                                                        shared_rel_pos_bias=shared_rel_pos_bias))
        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
    return x


def beit_forward_features(self, x):
    """
    Modification of timm.models.beit.py: Beit.forward_features to support arbitrary window sizes.
    """
    resolution = x.shape[2:]

    x = self.patch_embed(x)
    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    if self.pos_embed is not None:
        x = x + self.pos_embed
    x = self.pos_drop(x)

    rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
    for blk in self.blocks:
        if self.grad_checkpointing and not torch.jit.is_scripting():
            x = checkpoint(blk, x, shared_rel_pos_bias=rel_pos_bias)
        else:
            x = blk(x, resolution, shared_rel_pos_bias=rel_pos_bias)
    x = self.norm(x)
    return x


def _make_beit_backbone(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[0, 4, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    backbone = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
                                     start_index_readout)

    backbone.model.patch_embed.forward = types.MethodType(patch_embed_forward, backbone.model.patch_embed)
    backbone.model.forward_features = types.MethodType(beit_forward_features, backbone.model)

    for block in backbone.model.blocks:
        attn = block.attn
        attn._get_rel_pos_bias = types.MethodType(_get_rel_pos_bias, attn)
        attn.forward = types.MethodType(attention_forward, attn)
        attn.relative_position_indices = {}

        block.forward = types.MethodType(block_forward, block)

    return backbone


def _make_pretrained_beitl16_512(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("beit_large_patch16_512", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks is None else hooks

    features = [256, 512, 1024, 1024]

    return _make_beit_backbone(
        model,
        features=features,
        size=[512, 512],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_beitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("beit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks is None else hooks
    return _make_beit_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_beitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("beit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks is None else hooks
    return _make_beit_backbone(
        model,
        features=[96, 192, 384, 768],
        hooks=hooks,
        use_readout=use_readout,
    )
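Note: the `_get_rel_pos_bias` override above lets the BEiT backbone run at window sizes it was not trained for by bilinearly resizing the learned relative-position bias table. A standalone sketch of just that resizing step, with made-up shapes:

```python
# Illustration of the bias-table resizing used by _get_rel_pos_bias (shapes are invented).
import torch
import torch.nn.functional as F

num_heads = 12
old_window, new_window = (24, 24), (32, 32)   # e.g. 384px vs. 512px inputs at patch size 16

old_h, old_w = 2 * old_window[0] - 1, 2 * old_window[1] - 1
new_h, new_w = 2 * new_window[0] - 1, 2 * new_window[1] - 1

# (old_h * old_w + 3) rows: spatial offsets plus 3 special cls-token entries, as in BEiT.
table = torch.randn(old_h * old_w + 3, num_heads)

spatial = table[:-3].reshape(1, old_w, old_h, num_heads).permute(0, 3, 1, 2)  # 1 x nH x old_w x old_h
resized = F.interpolate(spatial, size=(new_h, new_w), mode="bilinear")
resized = resized.permute(0, 2, 3, 1).reshape(new_h * new_w, num_heads)

new_table = torch.cat([resized, table[-3:]])  # keep the cls-token rows unchanged
print(new_table.shape)                        # torch.Size([new_h * new_w + 3, num_heads])
```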
custom_midas_repo/midas/backbones/levit.py
ADDED
@@ -0,0 +1,106 @@
import custom_timm as timm
import torch
import torch.nn as nn
import numpy as np

from .utils import activations, get_activation, Transpose


def forward_levit(pretrained, x):
    pretrained.model.forward_features(x)

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]

    layer_1 = pretrained.act_postprocess1(layer_1)
    layer_2 = pretrained.act_postprocess2(layer_2)
    layer_3 = pretrained.act_postprocess3(layer_3)

    return layer_1, layer_2, layer_3


def _make_levit_backbone(
        model,
        hooks=[3, 11, 21],
        patch_grid=[14, 14]
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))

    pretrained.activations = activations

    patch_grid_size = np.array(patch_grid, dtype=int)

    pretrained.act_postprocess1 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
    )
    pretrained.act_postprocess2 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 2).astype(int)).tolist()))
    )
    pretrained.act_postprocess3 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((np.ceil(patch_grid_size / 4).astype(int)).tolist()))
    )

    return pretrained


class ConvTransposeNorm(nn.Sequential):
    """
    Modification of
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: ConvNorm
    such that ConvTranspose2d is used instead of Conv2d.
    """

    def __init__(
            self, in_chs, out_chs, kernel_size=1, stride=1, pad=0, dilation=1,
            groups=1, bn_weight_init=1):
        super().__init__()
        self.add_module('c',
                        nn.ConvTranspose2d(in_chs, out_chs, kernel_size, stride, pad, dilation, groups, bias=False))
        self.add_module('bn', nn.BatchNorm2d(out_chs))

        nn.init.constant_(self.bn.weight, bn_weight_init)

    @torch.no_grad()
    def fuse(self):
        c, bn = self._modules.values()
        w = bn.weight / (bn.running_var + bn.eps) ** 0.5
        w = c.weight * w[:, None, None, None]
        b = bn.bias - bn.running_mean * bn.weight / (bn.running_var + bn.eps) ** 0.5
        m = nn.ConvTranspose2d(
            w.size(1), w.size(0), w.shape[2:], stride=self.c.stride,
            padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups)
        m.weight.data.copy_(w)
        m.bias.data.copy_(b)
        return m


def stem_b4_transpose(in_chs, out_chs, activation):
    """
    Modification of
    https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/levit.py: stem_b16
    such that ConvTranspose2d is used instead of Conv2d and stem is also reduced to the half.
    """
    return nn.Sequential(
        ConvTransposeNorm(in_chs, out_chs, 3, 2, 1),
        activation(),
        ConvTransposeNorm(out_chs, out_chs // 2, 3, 2, 1),
        activation())


def _make_pretrained_levit_384(pretrained, hooks=None):
    model = timm.create_model("levit_384", pretrained=pretrained)

    hooks = [3, 11, 21] if hooks == None else hooks
    return _make_levit_backbone(
        model,
        hooks=hooks
    )
custom_midas_repo/midas/backbones/next_vit.py
ADDED
@@ -0,0 +1,39 @@
import custom_timm as timm

import torch.nn as nn

from pathlib import Path
from .utils import activations, forward_default, get_activation

from ..external.next_vit.classification.nextvit import *


def forward_next_vit(pretrained, x):
    return forward_default(pretrained, x, "forward")


def _make_next_vit_backbone(
        model,
        hooks=[2, 6, 36, 39],
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.features[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.features[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.features[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.features[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    return pretrained


def _make_pretrained_next_vit_large_6m(hooks=None):
    model = timm.create_model("nextvit_large")

    hooks = [2, 6, 36, 39] if hooks == None else hooks
    return _make_next_vit_backbone(
        model,
        hooks=hooks,
    )
custom_midas_repo/midas/backbones/swin.py
ADDED
@@ -0,0 +1,13 @@
import custom_timm as timm

from .swin_common import _make_swin_backbone


def _make_pretrained_swinl12_384(pretrained, hooks=None):
    model = timm.create_model("swin_large_patch4_window12_384", pretrained=pretrained)

    hooks = [1, 1, 17, 1] if hooks == None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )
custom_midas_repo/midas/backbones/swin2.py
ADDED
@@ -0,0 +1,34 @@
import custom_timm as timm

from .swin_common import _make_swin_backbone


def _make_pretrained_swin2l24_384(pretrained, hooks=None):
    model = timm.create_model("swinv2_large_window12to24_192to384_22kft1k", pretrained=pretrained)

    hooks = [1, 1, 17, 1] if hooks == None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )


def _make_pretrained_swin2b24_384(pretrained, hooks=None):
    model = timm.create_model("swinv2_base_window12to24_192to384_22kft1k", pretrained=pretrained)

    hooks = [1, 1, 17, 1] if hooks == None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks
    )


def _make_pretrained_swin2t16_256(pretrained, hooks=None):
    model = timm.create_model("swinv2_tiny_window16_256", pretrained=pretrained)

    hooks = [1, 1, 5, 1] if hooks == None else hooks
    return _make_swin_backbone(
        model,
        hooks=hooks,
        patch_grid=[64, 64]
    )
custom_midas_repo/midas/backbones/swin_common.py
ADDED
@@ -0,0 +1,52 @@
import torch

import torch.nn as nn
import numpy as np

from .utils import activations, forward_default, get_activation, Transpose


def forward_swin(pretrained, x):
    return forward_default(pretrained, x)


def _make_swin_backbone(
        model,
        hooks=[1, 1, 17, 1],
        patch_grid=[96, 96]
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.layers[0].blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.layers[1].blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.layers[2].blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.layers[3].blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    if hasattr(model, "patch_grid"):
        used_patch_grid = model.patch_grid
    else:
        used_patch_grid = patch_grid

    patch_grid_size = np.array(used_patch_grid, dtype=int)

    pretrained.act_postprocess1 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size(patch_grid_size.tolist()))
    )
    pretrained.act_postprocess2 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 2).tolist()))
    )
    pretrained.act_postprocess3 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 4).tolist()))
    )
    pretrained.act_postprocess4 = nn.Sequential(
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size((patch_grid_size // 8).tolist()))
    )

    return pretrained

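The act_postprocess modules above only reshape: each hooked Swin block emits tokens of shape (B, H*W, C), and Transpose plus nn.Unflatten turn them back into (B, C, H, W) feature maps at that stage's patch-grid resolution, halved per stage. A minimal sketch of the same reshape in plain PyTorch; the batch size, grid size and channel count below are illustrative only, not taken from this repository:

import torch

tokens = torch.randn(1, 96 * 96, 192)                  # (B, N, C) tokens from a first-stage block
fmap = tokens.transpose(1, 2).unflatten(2, (96, 96))   # (B, C, 96, 96) spatial feature map
print(fmap.shape)                                      # torch.Size([1, 192, 96, 96])
# later stages halve the grid: 48x48, 24x24, 12x12 (patch_grid // 2, // 4, // 8)
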
custom_midas_repo/midas/backbones/utils.py
ADDED
@@ -0,0 +1,249 @@
import torch

import torch.nn as nn


class Slice(nn.Module):
    def __init__(self, start_index=1):
        super(Slice, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        return x[:, self.start_index:]


class AddReadout(nn.Module):
    def __init__(self, start_index=1):
        super(AddReadout, self).__init__()
        self.start_index = start_index

    def forward(self, x):
        if self.start_index == 2:
            readout = (x[:, 0] + x[:, 1]) / 2
        else:
            readout = x[:, 0]
        return x[:, self.start_index:] + readout.unsqueeze(1)


class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index

        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = torch.cat((x[:, self.start_index:], readout), -1)

        return self.project(features)


class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x


activations = {}


def get_activation(name):
    def hook(model, input, output):
        activations[name] = output

    return hook


def forward_default(pretrained, x, function_name="forward_features"):
    exec(f"pretrained.model.{function_name}(x)")

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    if hasattr(pretrained, "act_postprocess1"):
        layer_1 = pretrained.act_postprocess1(layer_1)
    if hasattr(pretrained, "act_postprocess2"):
        layer_2 = pretrained.act_postprocess2(layer_2)
    if hasattr(pretrained, "act_postprocess3"):
        layer_3 = pretrained.act_postprocess3(layer_3)
    if hasattr(pretrained, "act_postprocess4"):
        layer_4 = pretrained.act_postprocess4(layer_4)

    return layer_1, layer_2, layer_3, layer_4


def forward_adapted_unflatten(pretrained, x, function_name="forward_features"):
    b, c, h, w = x.shape

    exec(f"glob = pretrained.model.{function_name}(x)")

    layer_1 = pretrained.activations["1"]
    layer_2 = pretrained.activations["2"]
    layer_3 = pretrained.activations["3"]
    layer_4 = pretrained.activations["4"]

    layer_1 = pretrained.act_postprocess1[0:2](layer_1)
    layer_2 = pretrained.act_postprocess2[0:2](layer_2)
    layer_3 = pretrained.act_postprocess3[0:2](layer_3)
    layer_4 = pretrained.act_postprocess4[0:2](layer_4)

    unflatten = nn.Sequential(
        nn.Unflatten(
            2,
            torch.Size(
                [
                    h // pretrained.model.patch_size[1],
                    w // pretrained.model.patch_size[0],
                ]
            ),
        )
    )

    if layer_1.ndim == 3:
        layer_1 = unflatten(layer_1)
    if layer_2.ndim == 3:
        layer_2 = unflatten(layer_2)
    if layer_3.ndim == 3:
        layer_3 = unflatten(layer_3)
    if layer_4.ndim == 3:
        layer_4 = unflatten(layer_4)

    layer_1 = pretrained.act_postprocess1[3: len(pretrained.act_postprocess1)](layer_1)
    layer_2 = pretrained.act_postprocess2[3: len(pretrained.act_postprocess2)](layer_2)
    layer_3 = pretrained.act_postprocess3[3: len(pretrained.act_postprocess3)](layer_3)
    layer_4 = pretrained.act_postprocess4[3: len(pretrained.act_postprocess4)](layer_4)

    return layer_1, layer_2, layer_3, layer_4


def get_readout_oper(vit_features, features, use_readout, start_index=1):
    if use_readout == "ignore":
        readout_oper = [Slice(start_index)] * len(features)
    elif use_readout == "add":
        readout_oper = [AddReadout(start_index)] * len(features)
    elif use_readout == "project":
        readout_oper = [
            ProjectReadout(vit_features, start_index) for out_feat in features
        ]
    else:
        assert (
            False
        ), "wrong operation for readout token, use_readout can be 'ignore', 'add', or 'project'"

    return readout_oper


def make_backbone_default(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[2, 5, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    pretrained = nn.Module()

    pretrained.model = model
    pretrained.model.blocks[hooks[0]].register_forward_hook(get_activation("1"))
    pretrained.model.blocks[hooks[1]].register_forward_hook(get_activation("2"))
    pretrained.model.blocks[hooks[2]].register_forward_hook(get_activation("3"))
    pretrained.model.blocks[hooks[3]].register_forward_hook(get_activation("4"))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index_readout)

    # 32, 48, 136, 384
    pretrained.act_postprocess1 = nn.Sequential(
        readout_oper[0],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[0],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[0],
            out_channels=features[0],
            kernel_size=4,
            stride=4,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess2 = nn.Sequential(
        readout_oper[1],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[1],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.ConvTranspose2d(
            in_channels=features[1],
            out_channels=features[1],
            kernel_size=2,
            stride=2,
            padding=0,
            bias=True,
            dilation=1,
            groups=1,
        ),
    )

    pretrained.act_postprocess3 = nn.Sequential(
        readout_oper[2],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[2],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
    )

    pretrained.act_postprocess4 = nn.Sequential(
        readout_oper[3],
        Transpose(1, 2),
        nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
        nn.Conv2d(
            in_channels=vit_features,
            out_channels=features[3],
            kernel_size=1,
            stride=1,
            padding=0,
        ),
        nn.Conv2d(
            in_channels=features[3],
            out_channels=features[3],
            kernel_size=3,
            stride=2,
            padding=1,
        ),
    )

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = [16, 16]

    return pretrained

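get_activation stores each hooked module's output in the shared activations dict under a string key, and forward_default then simply runs the model and reads keys "1" through "4" back out. A small self-contained sketch of that forward-hook pattern; the toy model and key names are illustrative only:

import torch
import torch.nn as nn

captured = {}

def save_to(name):
    def hook(module, inputs, output):
        captured[name] = output   # stash the layer output under a string key
    return hook

toy = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU(), nn.Conv2d(8, 16, 3, padding=1))
toy[0].register_forward_hook(save_to("1"))
toy[2].register_forward_hook(save_to("2"))

toy(torch.randn(1, 3, 32, 32))
print(captured["1"].shape, captured["2"].shape)   # (1, 8, 32, 32) (1, 16, 32, 32)
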
custom_midas_repo/midas/backbones/vit.py
ADDED
@@ -0,0 +1,221 @@
import torch
import torch.nn as nn
import custom_timm as timm
import types
import math
import torch.nn.functional as F

from .utils import (activations, forward_adapted_unflatten, get_activation, get_readout_oper,
                    make_backbone_default, Transpose)


def forward_vit(pretrained, x):
    return forward_adapted_unflatten(pretrained, x, "forward_flex")


def _resize_pos_embed(self, posemb, gs_h, gs_w):
    posemb_tok, posemb_grid = (
        posemb[:, : self.start_index],
        posemb[0, self.start_index:],
    )

    gs_old = int(math.sqrt(len(posemb_grid)))

    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
    posemb_grid = F.interpolate(posemb_grid, size=(gs_h, gs_w), mode="bilinear")
    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_h * gs_w, -1)

    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)

    return posemb


def forward_flex(self, x):
    b, c, h, w = x.shape

    pos_embed = self._resize_pos_embed(
        self.pos_embed, h // self.patch_size[1], w // self.patch_size[0]
    )

    B = x.shape[0]

    if hasattr(self.patch_embed, "backbone"):
        x = self.patch_embed.backbone(x)
        if isinstance(x, (list, tuple)):
            x = x[-1]  # last feature if backbone outputs list/tuple of features

    x = self.patch_embed.proj(x).flatten(2).transpose(1, 2)

    if getattr(self, "dist_token", None) is not None:
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
    else:
        if self.no_embed_class:
            x = x + pos_embed
        cls_tokens = self.cls_token.expand(
            B, -1, -1
        )  # stole cls_tokens impl from Phil Wang, thanks
        x = torch.cat((cls_tokens, x), dim=1)

    if not self.no_embed_class:
        x = x + pos_embed
    x = self.pos_drop(x)

    for blk in self.blocks:
        x = blk(x)

    x = self.norm(x)

    return x


def _make_vit_b16_backbone(
        model,
        features=[96, 192, 384, 768],
        size=[384, 384],
        hooks=[2, 5, 8, 11],
        vit_features=768,
        use_readout="ignore",
        start_index=1,
        start_index_readout=1,
):
    pretrained = make_backbone_default(model, features, size, hooks, vit_features, use_readout, start_index,
                                       start_index_readout)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitl16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_large_patch16_384", pretrained=pretrained)

    hooks = [5, 11, 17, 23] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model,
        features=[256, 512, 1024, 1024],
        hooks=hooks,
        vit_features=1024,
        use_readout=use_readout,
    )


def _make_pretrained_vitb16_384(pretrained, use_readout="ignore", hooks=None):
    model = timm.create_model("vit_base_patch16_384", pretrained=pretrained)

    hooks = [2, 5, 8, 11] if hooks == None else hooks
    return _make_vit_b16_backbone(
        model, features=[96, 192, 384, 768], hooks=hooks, use_readout=use_readout
    )


def _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=[0, 1, 8, 11],
        vit_features=768,
        patch_size=[16, 16],
        number_stages=2,
        use_vit_only=False,
        use_readout="ignore",
        start_index=1,
):
    pretrained = nn.Module()

    pretrained.model = model

    used_number_stages = 0 if use_vit_only else number_stages
    for s in range(used_number_stages):
        pretrained.model.patch_embed.backbone.stages[s].register_forward_hook(
            get_activation(str(s + 1))
        )
    for s in range(used_number_stages, 4):
        pretrained.model.blocks[hooks[s]].register_forward_hook(get_activation(str(s + 1)))

    pretrained.activations = activations

    readout_oper = get_readout_oper(vit_features, features, use_readout, start_index)

    for s in range(used_number_stages):
        value = nn.Sequential(nn.Identity(), nn.Identity(), nn.Identity())
        exec(f"pretrained.act_postprocess{s + 1}=value")
    for s in range(used_number_stages, 4):
        if s < number_stages:
            final_layer = nn.ConvTranspose2d(
                in_channels=features[s],
                out_channels=features[s],
                kernel_size=4 // (2 ** s),
                stride=4 // (2 ** s),
                padding=0,
                bias=True,
                dilation=1,
                groups=1,
            )
        elif s > number_stages:
            final_layer = nn.Conv2d(
                in_channels=features[3],
                out_channels=features[3],
                kernel_size=3,
                stride=2,
                padding=1,
            )
        else:
            final_layer = None

        layers = [
            readout_oper[s],
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([size[0] // 16, size[1] // 16])),
            nn.Conv2d(
                in_channels=vit_features,
                out_channels=features[s],
                kernel_size=1,
                stride=1,
                padding=0,
            ),
        ]
        if final_layer is not None:
            layers.append(final_layer)

        value = nn.Sequential(*layers)
        exec(f"pretrained.act_postprocess{s + 1}=value")

    pretrained.model.start_index = start_index
    pretrained.model.patch_size = patch_size

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model.forward_flex = types.MethodType(forward_flex, pretrained.model)

    # We inject this function into the VisionTransformer instances so that
    # we can use it with interpolated position embeddings without modifying the library source.
    pretrained.model._resize_pos_embed = types.MethodType(
        _resize_pos_embed, pretrained.model
    )

    return pretrained


def _make_pretrained_vitb_rn50_384(
        pretrained, use_readout="ignore", hooks=None, use_vit_only=False
):
    model = timm.create_model("vit_base_resnet50_384", pretrained=pretrained)

    hooks = [0, 1, 8, 11] if hooks == None else hooks
    return _make_vit_b_rn50_backbone(
        model,
        features=[256, 512, 768, 768],
        size=[384, 384],
        hooks=hooks,
        use_vit_only=use_vit_only,
        use_readout=use_readout,
    )

custom_midas_repo/midas/base_model.py
ADDED
@@ -0,0 +1,16 @@
import torch


class BaseModel(torch.nn.Module):
    def load(self, path):
        """Load model from file.

        Args:
            path (str): file path
        """
        parameters = torch.load(path, map_location=torch.device('cpu'))

        if "optimizer" in parameters:
            parameters = parameters["model"]

        self.load_state_dict(parameters)

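BaseModel.load accepts both plain state dicts and training-style checkpoints: when the loaded file also carries an "optimizer" entry, the weights are read from its "model" entry instead. A hedged sketch of the two checkpoint layouts it handles; the file names and toy module are placeholders:

import torch
import torch.nn as nn

net = nn.Linear(4, 2)

# plain state dict, loadable directly
torch.save(net.state_dict(), "weights_only.pt")

# training checkpoint: weights nested under "model", as BaseModel.load expects
torch.save({"model": net.state_dict(), "optimizer": {}}, "checkpoint.pt")

ckpt = torch.load("checkpoint.pt", map_location="cpu")
state = ckpt["model"] if "optimizer" in ckpt else ckpt
net.load_state_dict(state)
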
custom_midas_repo/midas/blocks.py
ADDED
@@ -0,0 +1,439 @@
import torch
import torch.nn as nn

from .backbones.beit import (
    _make_pretrained_beitl16_512,
    _make_pretrained_beitl16_384,
    _make_pretrained_beitb16_384,
    forward_beit,
)
from .backbones.swin_common import (
    forward_swin,
)
from .backbones.swin2 import (
    _make_pretrained_swin2l24_384,
    _make_pretrained_swin2b24_384,
    _make_pretrained_swin2t16_256,
)
from .backbones.swin import (
    _make_pretrained_swinl12_384,
)
from .backbones.levit import (
    _make_pretrained_levit_384,
    forward_levit,
)
from .backbones.vit import (
    _make_pretrained_vitb_rn50_384,
    _make_pretrained_vitl16_384,
    _make_pretrained_vitb16_384,
    forward_vit,
)

def _make_encoder(backbone, features, use_pretrained, groups=1, expand=False, exportable=True, hooks=None,
                  use_vit_only=False, use_readout="ignore", in_features=[96, 256, 512, 1024]):
    if backbone == "beitl16_512":
        pretrained = _make_pretrained_beitl16_512(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # BEiT_512-L (backbone)
    elif backbone == "beitl16_384":
        pretrained = _make_pretrained_beitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # BEiT_384-L (backbone)
    elif backbone == "beitb16_384":
        pretrained = _make_pretrained_beitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # BEiT_384-B (backbone)
    elif backbone == "swin2l24_384":
        pretrained = _make_pretrained_swin2l24_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [192, 384, 768, 1536], features, groups=groups, expand=expand
        )  # Swin2-L/12to24 (backbone)
    elif backbone == "swin2b24_384":
        pretrained = _make_pretrained_swin2b24_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [128, 256, 512, 1024], features, groups=groups, expand=expand
        )  # Swin2-B/12to24 (backbone)
    elif backbone == "swin2t16_256":
        pretrained = _make_pretrained_swin2t16_256(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # Swin2-T/16 (backbone)
    elif backbone == "swinl12_384":
        pretrained = _make_pretrained_swinl12_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [192, 384, 768, 1536], features, groups=groups, expand=expand
        )  # Swin-L/12 (backbone)
    elif backbone == "next_vit_large_6m":
        from .backbones.next_vit import _make_pretrained_next_vit_large_6m
        pretrained = _make_pretrained_next_vit_large_6m(hooks=hooks)
        scratch = _make_scratch(
            in_features, features, groups=groups, expand=expand
        )  # Next-ViT-L on ImageNet-1K-6M (backbone)
    elif backbone == "levit_384":
        pretrained = _make_pretrained_levit_384(
            use_pretrained, hooks=hooks
        )
        scratch = _make_scratch(
            [384, 512, 768], features, groups=groups, expand=expand
        )  # LeViT 384 (backbone)
    elif backbone == "vitl16_384":
        pretrained = _make_pretrained_vitl16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [256, 512, 1024, 1024], features, groups=groups, expand=expand
        )  # ViT-L/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb_rn50_384":
        pretrained = _make_pretrained_vitb_rn50_384(
            use_pretrained,
            hooks=hooks,
            use_vit_only=use_vit_only,
            use_readout=use_readout,
        )
        scratch = _make_scratch(
            [256, 512, 768, 768], features, groups=groups, expand=expand
        )  # ViT-H/16 - 85.0% Top1 (backbone)
    elif backbone == "vitb16_384":
        pretrained = _make_pretrained_vitb16_384(
            use_pretrained, hooks=hooks, use_readout=use_readout
        )
        scratch = _make_scratch(
            [96, 192, 384, 768], features, groups=groups, expand=expand
        )  # ViT-B/16 - 84.6% Top1 (backbone)
    elif backbone == "resnext101_wsl":
        pretrained = _make_pretrained_resnext101_wsl(use_pretrained)
        scratch = _make_scratch([256, 512, 1024, 2048], features, groups=groups, expand=expand)  # resnext101_wsl
    elif backbone == "efficientnet_lite3":
        pretrained = _make_pretrained_efficientnet_lite3(use_pretrained, exportable=exportable)
        scratch = _make_scratch([32, 48, 136, 384], features, groups=groups, expand=expand)  # efficientnet_lite3
    else:
        print(f"Backbone '{backbone}' not implemented")
        assert False

    return pretrained, scratch


def _make_scratch(in_shape, out_shape, groups=1, expand=False):
    scratch = nn.Module()

    out_shape1 = out_shape
    out_shape2 = out_shape
    out_shape3 = out_shape
    if len(in_shape) >= 4:
        out_shape4 = out_shape

    if expand:
        out_shape1 = out_shape
        out_shape2 = out_shape*2
        out_shape3 = out_shape*4
        if len(in_shape) >= 4:
            out_shape4 = out_shape*8

    scratch.layer1_rn = nn.Conv2d(
        in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer2_rn = nn.Conv2d(
        in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    scratch.layer3_rn = nn.Conv2d(
        in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
    )
    if len(in_shape) >= 4:
        scratch.layer4_rn = nn.Conv2d(
            in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
        )

    return scratch


def _make_pretrained_efficientnet_lite3(use_pretrained, exportable=False):
    efficientnet = torch.hub.load(
        "rwightman/gen-efficientnet-pytorch",
        "tf_efficientnet_lite3",
        pretrained=use_pretrained,
        exportable=exportable
    )
    return _make_efficientnet_backbone(efficientnet)


def _make_efficientnet_backbone(effnet):
    pretrained = nn.Module()

    pretrained.layer1 = nn.Sequential(
        effnet.conv_stem, effnet.bn1, effnet.act1, *effnet.blocks[0:2]
    )
    pretrained.layer2 = nn.Sequential(*effnet.blocks[2:3])
    pretrained.layer3 = nn.Sequential(*effnet.blocks[3:5])
    pretrained.layer4 = nn.Sequential(*effnet.blocks[5:9])

    return pretrained


def _make_resnet_backbone(resnet):
    pretrained = nn.Module()
    pretrained.layer1 = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1
    )

    pretrained.layer2 = resnet.layer2
    pretrained.layer3 = resnet.layer3
    pretrained.layer4 = resnet.layer4

    return pretrained


def _make_pretrained_resnext101_wsl(use_pretrained):
    resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl")
    return _make_resnet_backbone(resnet)


class Interpolate(nn.Module):
    """Interpolation module.
    """

    def __init__(self, scale_factor, mode, align_corners=False):
        """Init.

        Args:
            scale_factor (float): scaling
            mode (str): interpolation mode
        """
        super(Interpolate, self).__init__()

        self.interp = nn.functional.interpolate
        self.scale_factor = scale_factor
        self.mode = mode
        self.align_corners = align_corners

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: interpolated data
        """

        x = self.interp(
            x, scale_factor=self.scale_factor, mode=self.mode, align_corners=self.align_corners
        )

        return x


class ResidualConvUnit(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True
        )

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)

        return out + x


class FeatureFusionBlock(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock, self).__init__()

        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])

        output = self.resConfUnit2(output)

        output = nn.functional.interpolate(
            output, scale_factor=2, mode="bilinear", align_corners=True
        )

        return output


class ResidualConvUnit_custom(nn.Module):
    """Residual convolution module.
    """

    def __init__(self, features, activation, bn):
        """Init.

        Args:
            features (int): number of features
        """
        super().__init__()

        self.bn = bn

        self.groups=1

        self.conv1 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        self.conv2 = nn.Conv2d(
            features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups
        )

        if self.bn==True:
            self.bn1 = nn.BatchNorm2d(features)
            self.bn2 = nn.BatchNorm2d(features)

        self.activation = activation

        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input

        Returns:
            tensor: output
        """

        out = self.activation(x)
        out = self.conv1(out)
        if self.bn==True:
            out = self.bn1(out)

        out = self.activation(out)
        out = self.conv2(out)
        if self.bn==True:
            out = self.bn2(out)

        if self.groups > 1:
            out = self.conv_merge(out)

        return self.skip_add.add(out, x)

        # return out + x


class FeatureFusionBlock_custom(nn.Module):
    """Feature fusion block.
    """

    def __init__(self, features, activation, deconv=False, bn=False, expand=False, align_corners=True, size=None):
        """Init.

        Args:
            features (int): number of features
        """
        super(FeatureFusionBlock_custom, self).__init__()

        self.deconv = deconv
        self.align_corners = align_corners

        self.groups=1

        self.expand = expand
        out_features = features
        if self.expand==True:
            out_features = features//2

        self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)

        self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn)
        self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn)

        self.skip_add = nn.quantized.FloatFunctional()

        self.size=size

    def forward(self, *xs, size=None):
        """Forward pass.

        Returns:
            tensor: output
        """
        output = xs[0]

        if len(xs) == 2:
            res = self.resConfUnit1(xs[1])
            output = self.skip_add.add(output, res)
            # output += res

        output = self.resConfUnit2(output)

        if (size is None) and (self.size is None):
            modifier = {"scale_factor": 2}
        elif size is None:
            modifier = {"size": self.size}
        else:
            modifier = {"size": size}

        output = nn.functional.interpolate(
            output, **modifier, mode="bilinear", align_corners=self.align_corners
        )

        output = self.out_conv(output)

        return output

custom_midas_repo/midas/dpt_depth.py
ADDED
@@ -0,0 +1,166 @@
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import (
    FeatureFusionBlock_custom,
    Interpolate,
    _make_encoder,
    forward_beit,
    forward_swin,
    forward_levit,
    forward_vit,
)
from .backbones.levit import stem_b4_transpose
from custom_timm.models.layers import get_act_layer


def _make_fusion_block(features, use_bn, size = None):
    return FeatureFusionBlock_custom(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )


class DPT(BaseModel):
    def __init__(
            self,
            head,
            features=256,
            backbone="vitb_rn50_384",
            readout="project",
            channels_last=False,
            use_bn=False,
            **kwargs
    ):

        super(DPT, self).__init__()

        self.channels_last = channels_last

        # For the Swin, Swin 2, LeViT and Next-ViT Transformers, the hierarchical architectures prevent setting the
        # hooks freely. Instead, the hooks have to be chosen according to the ranges specified in the comments.
        hooks = {
            "beitl16_512": [5, 11, 17, 23],
            "beitl16_384": [5, 11, 17, 23],
            "beitb16_384": [2, 5, 8, 11],
            "swin2l24_384": [1, 1, 17, 1],  # Allowed ranges: [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "swin2b24_384": [1, 1, 17, 1],  # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "swin2t16_256": [1, 1, 5, 1],  # [0, 1], [0, 1], [ 0, 5], [ 0, 1]
            "swinl12_384": [1, 1, 17, 1],  # [0, 1], [0, 1], [ 0, 17], [ 0, 1]
            "next_vit_large_6m": [2, 6, 36, 39],  # [0, 2], [3, 6], [ 7, 36], [37, 39]
            "levit_384": [3, 11, 21],  # [0, 3], [6, 11], [14, 21]
            "vitb_rn50_384": [0, 1, 8, 11],
            "vitb16_384": [2, 5, 8, 11],
            "vitl16_384": [5, 11, 17, 23],
        }[backbone]

        if "next_vit" in backbone:
            in_features = {
                "next_vit_large_6m": [96, 256, 512, 1024],
            }[backbone]
        else:
            in_features = None

        # Instantiate backbone and reassemble blocks
        self.pretrained, self.scratch = _make_encoder(
            backbone,
            features,
            False,  # Set to true if you want to train from scratch, uses ImageNet weights
            groups=1,
            expand=False,
            exportable=False,
            hooks=hooks,
            use_readout=readout,
            in_features=in_features,
        )

        self.number_layers = len(hooks) if hooks is not None else 4
        size_refinenet3 = None
        self.scratch.stem_transpose = None

        if "beit" in backbone:
            self.forward_transformer = forward_beit
        elif "swin" in backbone:
            self.forward_transformer = forward_swin
        elif "next_vit" in backbone:
            from .backbones.next_vit import forward_next_vit
            self.forward_transformer = forward_next_vit
        elif "levit" in backbone:
            self.forward_transformer = forward_levit
            size_refinenet3 = 7
            self.scratch.stem_transpose = stem_b4_transpose(256, 128, get_act_layer("hard_swish"))
        else:
            self.forward_transformer = forward_vit

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn, size_refinenet3)
        if self.number_layers >= 4:
            self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        self.scratch.output_conv = head


    def forward(self, x):
        if self.channels_last == True:
            x.contiguous(memory_format=torch.channels_last)

        layers = self.forward_transformer(self.pretrained, x)
        if self.number_layers == 3:
            layer_1, layer_2, layer_3 = layers
        else:
            layer_1, layer_2, layer_3, layer_4 = layers

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        if self.number_layers >= 4:
            layer_4_rn = self.scratch.layer4_rn(layer_4)

        if self.number_layers == 3:
            path_3 = self.scratch.refinenet3(layer_3_rn, size=layer_2_rn.shape[2:])
        else:
            path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
            path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        if self.scratch.stem_transpose is not None:
            path_1 = self.scratch.stem_transpose(path_1)

        out = self.scratch.output_conv(path_1)

        return out


class DPTDepthModel(DPT):
    def __init__(self, path=None, non_negative=True, **kwargs):
        features = kwargs["features"] if "features" in kwargs else 256
        head_features_1 = kwargs["head_features_1"] if "head_features_1" in kwargs else features
        head_features_2 = kwargs["head_features_2"] if "head_features_2" in kwargs else 32
        kwargs.pop("head_features_1", None)
        kwargs.pop("head_features_2", None)

        head = nn.Sequential(
            nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear", align_corners=True),
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        super().__init__(head, **kwargs)

        if path is not None:
            self.load(path)

    def forward(self, x):
        return super().forward(x).squeeze(dim=1)

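End to end, DPTDepthModel takes a normalized (B, 3, H, W) image and returns a (B, H, W) inverse-depth map: the backbone hooks produce four feature maps, the scratch.layerX_rn convolutions align channel widths, the refinenets fuse coarse to fine, and the head upsamples back to input resolution. A hedged shape-check sketch, assuming this repository and its bundled custom_timm are importable and that randomly initialized weights are acceptable for the check:

import torch
from custom_midas_repo.midas.dpt_depth import DPTDepthModel

model = DPTDepthModel(path=None, backbone="vitl16_384", non_negative=True)
model.eval()

with torch.no_grad():
    depth = model(torch.randn(1, 3, 384, 384))   # normalized image batch

print(depth.shape)   # torch.Size([1, 384, 384])
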
custom_midas_repo/midas/midas_net.py
ADDED
@@ -0,0 +1,76 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, Interpolate, _make_encoder


class MidasNet(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=256, non_negative=True):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
        """
        print("Loading weights: ", path)

        super(MidasNet, self).__init__()

        use_pretrained = False if path is None else True

        self.pretrained, self.scratch = _make_encoder(backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained)

        self.scratch.refinenet4 = FeatureFusionBlock(features)
        self.scratch.refinenet3 = FeatureFusionBlock(features)
        self.scratch.refinenet2 = FeatureFusionBlock(features)
        self.scratch.refinenet1 = FeatureFusionBlock(features)

        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
        )

        if path:
            self.load(path)

    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """

        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)

custom_midas_repo/midas/midas_net_custom.py
ADDED
@@ -0,0 +1,128 @@
"""MidasNet: Network for monocular depth estimation trained by mixing several datasets.
This file contains code that is adapted from
https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py
"""
import torch
import torch.nn as nn

from .base_model import BaseModel
from .blocks import FeatureFusionBlock, FeatureFusionBlock_custom, Interpolate, _make_encoder


class MidasNet_small(BaseModel):
    """Network for monocular depth estimation.
    """

    def __init__(self, path=None, features=64, backbone="efficientnet_lite3", non_negative=True, exportable=True, channels_last=False, align_corners=True,
                 blocks={'expand': True}):
        """Init.

        Args:
            path (str, optional): Path to saved model. Defaults to None.
            features (int, optional): Number of features. Defaults to 256.
            backbone (str, optional): Backbone network for encoder. Defaults to resnet50
        """
        print("Loading weights: ", path)

        super(MidasNet_small, self).__init__()

        use_pretrained = False if path else True

        self.channels_last = channels_last
        self.blocks = blocks
        self.backbone = backbone

        self.groups = 1

        features1=features
        features2=features
        features3=features
        features4=features
        self.expand = False
        if "expand" in self.blocks and self.blocks['expand'] == True:
            self.expand = True
            features1=features
            features2=features*2
            features3=features*4
            features4=features*8

        self.pretrained, self.scratch = _make_encoder(self.backbone, features, use_pretrained, groups=self.groups, expand=self.expand, exportable=exportable)

        self.scratch.activation = nn.ReLU(False)

        self.scratch.refinenet4 = FeatureFusionBlock_custom(features4, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet3 = FeatureFusionBlock_custom(features3, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet2 = FeatureFusionBlock_custom(features2, self.scratch.activation, deconv=False, bn=False, expand=self.expand, align_corners=align_corners)
        self.scratch.refinenet1 = FeatureFusionBlock_custom(features1, self.scratch.activation, deconv=False, bn=False, align_corners=align_corners)


        self.scratch.output_conv = nn.Sequential(
            nn.Conv2d(features, features//2, kernel_size=3, stride=1, padding=1, groups=self.groups),
            Interpolate(scale_factor=2, mode="bilinear"),
            nn.Conv2d(features//2, 32, kernel_size=3, stride=1, padding=1),
            self.scratch.activation,
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True) if non_negative else nn.Identity(),
            nn.Identity(),
        )

        if path:
            self.load(path)


    def forward(self, x):
        """Forward pass.

        Args:
            x (tensor): input data (image)

        Returns:
            tensor: depth
        """
        if self.channels_last==True:
            print("self.channels_last = ", self.channels_last)
            x.contiguous(memory_format=torch.channels_last)


        layer_1 = self.pretrained.layer1(x)
        layer_2 = self.pretrained.layer2(layer_1)
        layer_3 = self.pretrained.layer3(layer_2)
        layer_4 = self.pretrained.layer4(layer_3)

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)


        path_4 = self.scratch.refinenet4(layer_4_rn)
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn)
        path_2 = self.scratch.refinenet2(path_3, layer_2_rn)
        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)

        out = self.scratch.output_conv(path_1)

        return torch.squeeze(out, dim=1)



def fuse_model(m):
    prev_previous_type = nn.Identity()
    prev_previous_name = ''
    previous_type = nn.Identity()
    previous_name = ''
    for name, module in m.named_modules():
        if prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d and type(module) == nn.ReLU:
            # print("FUSED ", prev_previous_name, previous_name, name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name, name], inplace=True)
        elif prev_previous_type == nn.Conv2d and previous_type == nn.BatchNorm2d:
            # print("FUSED ", prev_previous_name, previous_name)
            torch.quantization.fuse_modules(m, [prev_previous_name, previous_name], inplace=True)
        # elif previous_type == nn.Conv2d and type(module) == nn.ReLU:
        #    print("FUSED ", previous_name, name)
        #    torch.quantization.fuse_modules(m, [previous_name, name], inplace=True)

        prev_previous_type = previous_type
        prev_previous_name = previous_name
        previous_type = type(module)
        previous_name = name

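fuse_model walks the module tree and folds Conv2d + BatchNorm2d (+ ReLU) patterns in place with torch.quantization.fuse_modules, which is defined for inference graphs, so the model should be in eval mode first. A hedged usage sketch; note that constructing MidasNet_small with path=None makes it fetch the EfficientNet-Lite3 encoder via torch.hub, so this assumes network access on first run:

import torch
from custom_midas_repo.midas.midas_net_custom import MidasNet_small, fuse_model

net = MidasNet_small(path=None, backbone="efficientnet_lite3")
net.eval()        # fusion is only valid for inference
fuse_model(net)   # folds conv/bn(/relu) triples in place before quantization or export
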
custom_midas_repo/midas/model_loader.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import cv2
import torch

from custom_midas_repo.midas.dpt_depth import DPTDepthModel
from custom_midas_repo.midas.midas_net import MidasNet
from custom_midas_repo.midas.midas_net_custom import MidasNet_small
from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet

from torchvision.transforms import Compose

default_models = {
    "dpt_beit_large_512": "weights/dpt_beit_large_512.pt",
    "dpt_beit_large_384": "weights/dpt_beit_large_384.pt",
    "dpt_beit_base_384": "weights/dpt_beit_base_384.pt",
    "dpt_swin2_large_384": "weights/dpt_swin2_large_384.pt",
    "dpt_swin2_base_384": "weights/dpt_swin2_base_384.pt",
    "dpt_swin2_tiny_256": "weights/dpt_swin2_tiny_256.pt",
    "dpt_swin_large_384": "weights/dpt_swin_large_384.pt",
    "dpt_next_vit_large_384": "weights/dpt_next_vit_large_384.pt",
    "dpt_levit_224": "weights/dpt_levit_224.pt",
    "dpt_large_384": "weights/dpt_large_384.pt",
    "dpt_hybrid_384": "weights/dpt_hybrid_384.pt",
    "midas_v21_384": "weights/midas_v21_384.pt",
    "midas_v21_small_256": "weights/midas_v21_small_256.pt",
    "openvino_midas_v21_small_256": "weights/openvino_midas_v21_small_256.xml",
}


def load_model(device, model_path, model_type="dpt_large_384", optimize=True, height=None, square=False):
    """Load the specified network.

    Args:
        device (device): the torch device used
        model_path (str): path to saved model
        model_type (str): the type of the model to be loaded
        optimize (bool): optimize the model to half-floats on CUDA?
        height (int): inference encoder image height
        square (bool): resize to a square resolution?

    Returns:
        The loaded network, the transform which prepares images as input to the network and the dimensions of the
        network input
    """
    if "openvino" in model_type:
        from openvino.runtime import Core

    keep_aspect_ratio = not square

    if model_type == "dpt_beit_large_512":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitl16_512",
            non_negative=True,
        )
        net_w, net_h = 512, 512
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_beit_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_beit_base_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="beitb16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2l24_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_base_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2b24_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin2_tiny_256":
        model = DPTDepthModel(
            path=model_path,
            backbone="swin2t16_256",
            non_negative=True,
        )
        net_w, net_h = 256, 256
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_swin_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="swinl12_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_next_vit_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="next_vit_large_6m",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # We change the notation from dpt_levit_224 (MiDaS notation) to levit_384 (timm notation) here, where the 224 refers
    # to the resolution 224x224 used by LeViT and 384 is the first entry of the embed_dim, see _cfg and model_cfgs of
    # https://github.com/rwightman/pytorch-image-models/blob/main/timm/models/levit.py
    # (commit id: 927f031293a30afb940fff0bee34b85d9c059b0e)
    elif model_type == "dpt_levit_224":
        model = DPTDepthModel(
            path=model_path,
            backbone="levit_384",
            non_negative=True,
            head_features_1=64,
            head_features_2=8,
        )
        net_w, net_h = 224, 224
        keep_aspect_ratio = False
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_large_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="vitl16_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "dpt_hybrid_384":
        model = DPTDepthModel(
            path=model_path,
            backbone="vitb_rn50_384",
            non_negative=True,
        )
        net_w, net_h = 384, 384
        resize_mode = "minimal"
        normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    elif model_type == "midas_v21_384":
        model = MidasNet(model_path, non_negative=True)
        net_w, net_h = 384, 384
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "midas_v21_small_256":
        model = MidasNet_small(model_path, features=64, backbone="efficientnet_lite3", exportable=True,
                               non_negative=True, blocks={'expand': True})
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    elif model_type == "openvino_midas_v21_small_256":
        ie = Core()
        uncompiled_model = ie.read_model(model=model_path)
        model = ie.compile_model(uncompiled_model, "CPU")
        net_w, net_h = 256, 256
        resize_mode = "upper_bound"
        normalization = NormalizeImage(
            mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
        )

    else:
        print(f"model_type '{model_type}' not implemented, use: --model_type large")
        assert False

    if not "openvino" in model_type:
        print("Model loaded, number of parameters = {:.0f}M".format(sum(p.numel() for p in model.parameters()) / 1e6))
    else:
        print("Model loaded, optimized with OpenVINO")

    if "openvino" in model_type:
        keep_aspect_ratio = False

    if height is not None:
        net_w, net_h = height, height

    transform = Compose(
        [
            Resize(
                net_w,
                net_h,
                resize_target=None,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=32,
                resize_method=resize_mode,
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            normalization,
            PrepareForNet(),
        ]
    )

    if not "openvino" in model_type:
        model.eval()

    if optimize and (device == torch.device("cuda")):
        if not "openvino" in model_type:
            model = model.to(memory_format=torch.channels_last)
            model = model.half()
        else:
            print("Error: OpenVINO models are already optimized. No optimization to half-float possible.")
            exit()

    if not "openvino" in model_type:
        model.to(device)

    return model, transform, net_w, net_h
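The loader above returns the network together with the preprocessing transform that matches it. A minimal usage sketch, assuming a local dpt_beit_large_512.pt checkpoint and an input.png test image (neither is part of this diff), with optimization disabled to keep the tensors in float32:

import cv2
import torch
from custom_midas_repo.midas.model_loader import load_model

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, transform, net_w, net_h = load_model(
    device, "weights/dpt_beit_large_512.pt",
    model_type="dpt_beit_large_512", optimize=False)

# the transform expects an RGB image scaled to [0, 1]
img = cv2.cvtColor(cv2.imread("input.png"), cv2.COLOR_BGR2RGB) / 255.0
sample = transform({"image": img})["image"]

with torch.no_grad():
    prediction = model(torch.from_numpy(sample).unsqueeze(0).to(device))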
custom_midas_repo/midas/transforms.py
ADDED
@@ -0,0 +1,234 @@
import numpy as np
import cv2
import math


def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
    """Resize the sample to ensure the given size. Keeps aspect ratio.

    Args:
        sample (dict): sample
        size (tuple): image size

    Returns:
        tuple: new size
    """
    shape = list(sample["disparity"].shape)

    if shape[0] >= size[0] and shape[1] >= size[1]:
        return sample

    scale = [0, 0]
    scale[0] = size[0] / shape[0]
    scale[1] = size[1] / shape[1]

    scale = max(scale)

    shape[0] = math.ceil(scale * shape[0])
    shape[1] = math.ceil(scale * shape[1])

    # resize
    sample["image"] = cv2.resize(
        sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
    )

    sample["disparity"] = cv2.resize(
        sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
    )
    sample["mask"] = cv2.resize(
        sample["mask"].astype(np.float32),
        tuple(shape[::-1]),
        interpolation=cv2.INTER_NEAREST,
    )
    sample["mask"] = sample["mask"].astype(bool)

    return tuple(shape)


class Resize(object):
    """Resize sample to given size (width, height).
    """

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional):
                True: Resize the full sample (image, mask, target).
                False: Resize image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional):
                True: Keep the aspect ratio of the input sample.
                Output sample might not have the given width and height, and
                resize behaviour depends on the parameter 'resize_method'.
                Defaults to False.
            ensure_multiple_of (int, optional):
                Output width and height is constrained to be multiple of this parameter.
                Defaults to 1.
            resize_method (str, optional):
                "lower_bound": Output will be at least as large as the given size.
                "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
                "minimal": Scale as least as possible. (Output size might be smaller than given size.)
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        # determine new height and width
        scale_height = self.__height / height
        scale_width = self.__width / width

        if self.__keep_aspect_ratio:
            if self.__resize_method == "lower_bound":
                # scale such that output size is lower bound
                if scale_width > scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "upper_bound":
                # scale such that output size is upper bound
                if scale_width < scale_height:
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            elif self.__resize_method == "minimal":
                # scale as least as possible
                if abs(1 - scale_width) < abs(1 - scale_height):
                    # fit width
                    scale_height = scale_width
                else:
                    # fit height
                    scale_width = scale_height
            else:
                raise ValueError(
                    f"resize_method {self.__resize_method} not implemented"
                )

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, min_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, min_val=self.__width
            )
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(
                scale_height * height, max_val=self.__height
            )
            new_width = self.constrain_to_multiple_of(
                scale_width * width, max_val=self.__width
            )
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_height * height)
            new_width = self.constrain_to_multiple_of(scale_width * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(
            sample["image"].shape[1], sample["image"].shape[0]
        )

        # resize sample
        sample["image"] = cv2.resize(
            sample["image"],
            (width, height),
            interpolation=self.__image_interpolation_method,
        )

        if self.__resize_target:
            if "disparity" in sample:
                sample["disparity"] = cv2.resize(
                    sample["disparity"],
                    (width, height),
                    interpolation=cv2.INTER_NEAREST,
                )

            if "depth" in sample:
                sample["depth"] = cv2.resize(
                    sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
                )

            sample["mask"] = cv2.resize(
                sample["mask"].astype(np.float32),
                (width, height),
                interpolation=cv2.INTER_NEAREST,
            )
            sample["mask"] = sample["mask"].astype(bool)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std.
    """

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std

        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input.
    """

    def __init__(self):
        pass

    def __call__(self, sample):
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "mask" in sample:
            sample["mask"] = sample["mask"].astype(np.float32)
            sample["mask"] = np.ascontiguousarray(sample["mask"])

        if "disparity" in sample:
            disparity = sample["disparity"].astype(np.float32)
            sample["disparity"] = np.ascontiguousarray(disparity)

        if "depth" in sample:
            depth = sample["depth"].astype(np.float32)
            sample["depth"] = np.ascontiguousarray(depth)

        return sample
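These transforms operate on a plain dict sample rather than on tensors. A small sketch of how they compose; the random image is only a stand-in for real input:

import numpy as np
from torchvision.transforms import Compose
from custom_midas_repo.midas.transforms import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(384, 384, resize_target=None, keep_aspect_ratio=True,
           ensure_multiple_of=32, resize_method="minimal"),
    NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    PrepareForNet(),
])

sample = {"image": np.random.rand(480, 640, 3).astype(np.float32)}
out = transform(sample)["image"]  # float32, CHW, sides rounded to multiples of 32
print(out.shape)                  # (3, 384, 512) for this 480x640 input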
custom_mmpkg/__init__.py
ADDED
@@ -0,0 +1 @@
# Dummy file ensuring this package will be recognized
custom_mmpkg/custom_mmcv/__init__.py
ADDED
@@ -0,0 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
# flake8: noqa
from .arraymisc import *
from .fileio import *
from .image import *
from .utils import *
from .version import *
from .video import *
from .visualization import *

# The following modules are not imported to this level, so mmcv may be used
# without PyTorch.
# - runner
# - parallel
# - op
custom_mmpkg/custom_mmcv/arraymisc/__init__.py
ADDED
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .quantization import dequantize, quantize

__all__ = ['quantize', 'dequantize']
custom_mmpkg/custom_mmcv/arraymisc/quantization.py
ADDED
@@ -0,0 +1,55 @@
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np


def quantize(arr, min_val, max_val, levels, dtype=np.int64):
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the quantized array.

    Returns:
        tuple: Quantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    arr = np.clip(arr, min_val, max_val) - min_val
    quantized_arr = np.minimum(
        np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)

    return quantized_arr


def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
    """Dequantize an array.

    Args:
        arr (ndarray): Input array.
        min_val (scalar): Minimum value to be clipped.
        max_val (scalar): Maximum value to be clipped.
        levels (int): Quantization levels.
        dtype (np.type): The type of the dequantized array.

    Returns:
        tuple: Dequantized array.
    """
    if not (isinstance(levels, int) and levels > 1):
        raise ValueError(
            f'levels must be a positive integer, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
                                                   min_val) / levels + min_val

    return dequantized_arr
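A quick round trip through the two helpers, with illustrative values; this assumes the custom_mmpkg package is importable in the environment:

import numpy as np
from custom_mmpkg.custom_mmcv.arraymisc import quantize, dequantize

vals = np.array([-40.0, -0.2, 0.0, 7.5, 120.0])
q = quantize(vals, min_val=-20.0, max_val=20.0, levels=255)    # ints in [0, 254]
back = dequantize(q, min_val=-20.0, max_val=20.0, levels=255)  # clipped, approximate
print(q, back)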
custom_mmpkg/custom_mmcv/cnn/__init__.py
ADDED
@@ -0,0 +1,41 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
# yapf: disable
from .bricks import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                     PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS,
                     ContextBlock, Conv2d, Conv3d, ConvAWS2d, ConvModule,
                     ConvTranspose2d, ConvTranspose3d, ConvWS2d,
                     DepthwiseSeparableConvModule, GeneralizedAttention,
                     HSigmoid, HSwish, Linear, MaxPool2d, MaxPool3d,
                     NonLocal1d, NonLocal2d, NonLocal3d, Scale, Swish,
                     build_activation_layer, build_conv_layer,
                     build_norm_layer, build_padding_layer, build_plugin_layer,
                     build_upsample_layer, conv_ws_2d, is_norm)
from .builder import MODELS, build_model_from_cfg
# yapf: enable
from .resnet import ResNet, make_res_layer
from .utils import (INITIALIZERS, Caffe2XavierInit, ConstantInit, KaimingInit,
                    NormalInit, PretrainedInit, TruncNormalInit, UniformInit,
                    XavierInit, bias_init_with_prob, caffe2_xavier_init,
                    constant_init, fuse_conv_bn, get_model_complexity_info,
                    initialize, kaiming_init, normal_init, trunc_normal_init,
                    uniform_init, xavier_init)
from .vgg import VGG, make_vgg_layer

__all__ = [
    'AlexNet', 'VGG', 'make_vgg_layer', 'ResNet', 'make_res_layer',
    'constant_init', 'xavier_init', 'normal_init', 'trunc_normal_init',
    'uniform_init', 'kaiming_init', 'caffe2_xavier_init',
    'bias_init_with_prob', 'ConvModule', 'build_activation_layer',
    'build_conv_layer', 'build_norm_layer', 'build_padding_layer',
    'build_upsample_layer', 'build_plugin_layer', 'is_norm', 'NonLocal1d',
    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'HSigmoid', 'Swish', 'HSwish',
    'GeneralizedAttention', 'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS',
    'PADDING_LAYERS', 'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale',
    'get_model_complexity_info', 'conv_ws_2d', 'ConvAWS2d', 'ConvWS2d',
    'fuse_conv_bn', 'DepthwiseSeparableConvModule', 'Linear', 'Conv2d',
    'ConvTranspose2d', 'MaxPool2d', 'ConvTranspose3d', 'MaxPool3d', 'Conv3d',
    'initialize', 'INITIALIZERS', 'ConstantInit', 'XavierInit', 'NormalInit',
    'TruncNormalInit', 'UniformInit', 'KaimingInit', 'PretrainedInit',
    'Caffe2XavierInit', 'MODELS', 'build_model_from_cfg'
]
custom_mmpkg/custom_mmcv/cnn/alexnet.py
ADDED
@@ -0,0 +1,61 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging

import torch.nn as nn


class AlexNet(nn.Module):
    """AlexNet backbone.

    Args:
        num_classes (int): number of classes for classification.
    """

    def __init__(self, num_classes=-1):
        super(AlexNet, self).__init__()
        self.num_classes = num_classes
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(64, 192, kernel_size=5, padding=2),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            nn.Conv2d(192, 384, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
        )
        if self.num_classes > 0:
            self.classifier = nn.Sequential(
                nn.Dropout(),
                nn.Linear(256 * 6 * 6, 4096),
                nn.ReLU(inplace=True),
                nn.Dropout(),
                nn.Linear(4096, 4096),
                nn.ReLU(inplace=True),
                nn.Linear(4096, num_classes),
            )

    def init_weights(self, pretrained=None):
        if isinstance(pretrained, str):
            logger = logging.getLogger()
            from ..runner import load_checkpoint
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # use default initializer
            pass
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):

        x = self.features(x)
        if self.num_classes > 0:
            x = x.view(x.size(0), 256 * 6 * 6)
            x = self.classifier(x)

        return x
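When num_classes is left at its default of -1 the classifier head is skipped and the module acts as a pure feature extractor. A short shape check, assuming the vendored package imports cleanly:

import torch
from custom_mmpkg.custom_mmcv.cnn import AlexNet

backbone = AlexNet()                       # num_classes=-1 -> features only
feats = backbone(torch.randn(1, 3, 224, 224))
print(feats.shape)                         # torch.Size([1, 256, 6, 6])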
custom_mmpkg/custom_mmcv/cnn/bricks/__init__.py
ADDED
@@ -0,0 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .activation import build_activation_layer
from .context_block import ContextBlock
from .conv import build_conv_layer
from .conv2d_adaptive_padding import Conv2dAdaptivePadding
from .conv_module import ConvModule
from .conv_ws import ConvAWS2d, ConvWS2d, conv_ws_2d
from .depthwise_separable_conv_module import DepthwiseSeparableConvModule
from .drop import Dropout, DropPath
from .generalized_attention import GeneralizedAttention
from .hsigmoid import HSigmoid
from .hswish import HSwish
from .non_local import NonLocal1d, NonLocal2d, NonLocal3d
from .norm import build_norm_layer, is_norm
from .padding import build_padding_layer
from .plugin import build_plugin_layer
from .registry import (ACTIVATION_LAYERS, CONV_LAYERS, NORM_LAYERS,
                       PADDING_LAYERS, PLUGIN_LAYERS, UPSAMPLE_LAYERS)
from .scale import Scale
from .swish import Swish
from .upsample import build_upsample_layer
from .wrappers import (Conv2d, Conv3d, ConvTranspose2d, ConvTranspose3d,
                       Linear, MaxPool2d, MaxPool3d)

__all__ = [
    'ConvModule', 'build_activation_layer', 'build_conv_layer',
    'build_norm_layer', 'build_padding_layer', 'build_upsample_layer',
    'build_plugin_layer', 'is_norm', 'HSigmoid', 'HSwish', 'NonLocal1d',
    'NonLocal2d', 'NonLocal3d', 'ContextBlock', 'GeneralizedAttention',
    'ACTIVATION_LAYERS', 'CONV_LAYERS', 'NORM_LAYERS', 'PADDING_LAYERS',
    'UPSAMPLE_LAYERS', 'PLUGIN_LAYERS', 'Scale', 'ConvAWS2d', 'ConvWS2d',
    'conv_ws_2d', 'DepthwiseSeparableConvModule', 'Swish', 'Linear',
    'Conv2dAdaptivePadding', 'Conv2d', 'ConvTranspose2d', 'MaxPool2d',
    'ConvTranspose3d', 'MaxPool3d', 'Conv3d', 'Dropout', 'DropPath'
]
custom_mmpkg/custom_mmcv/cnn/bricks/activation.py
ADDED
@@ -0,0 +1,92 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from custom_mmpkg.custom_mmcv.utils import TORCH_VERSION, build_from_cfg, digit_version
from .registry import ACTIVATION_LAYERS

for module in [
        nn.ReLU, nn.LeakyReLU, nn.PReLU, nn.RReLU, nn.ReLU6, nn.ELU,
        nn.Sigmoid, nn.Tanh
]:
    ACTIVATION_LAYERS.register_module(module=module)


@ACTIVATION_LAYERS.register_module(name='Clip')
@ACTIVATION_LAYERS.register_module()
class Clamp(nn.Module):
    """Clamp activation layer.

    This activation function is to clamp the feature map value within
    :math:`[min, max]`. More details can be found in ``torch.clamp()``.

    Args:
        min (Number | optional): Lower-bound of the range to be clamped to.
            Default to -1.
        max (Number | optional): Upper-bound of the range to be clamped to.
            Default to 1.
    """

    def __init__(self, min=-1., max=1.):
        super(Clamp, self).__init__()
        self.min = min
        self.max = max

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: Clamped tensor.
        """
        return torch.clamp(x, min=self.min, max=self.max)


class GELU(nn.Module):
    r"""Applies the Gaussian Error Linear Units function:

    .. math::
        \text{GELU}(x) = x * \Phi(x)
    where :math:`\Phi(x)` is the Cumulative Distribution Function for
    Gaussian Distribution.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input

    .. image:: scripts/activation_images/GELU.png

    Examples::

        >>> m = nn.GELU()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def forward(self, input):
        return F.gelu(input)


if (TORCH_VERSION == 'parrots'
        or digit_version(TORCH_VERSION) < digit_version('1.4')):
    ACTIVATION_LAYERS.register_module(module=GELU)
else:
    ACTIVATION_LAYERS.register_module(module=nn.GELU)


def build_activation_layer(cfg):
    """Build activation layer.

    Args:
        cfg (dict): The activation layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate an activation layer.

    Returns:
        nn.Module: Created activation layer.
    """
    return build_from_cfg(cfg, ACTIVATION_LAYERS)
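The registry pattern above means activations are constructed from config dicts rather than instantiated directly. A minimal sketch, assuming the vendored custom_mmpkg package is importable:

from custom_mmpkg.custom_mmcv.cnn import build_activation_layer

relu = build_activation_layer(dict(type='ReLU', inplace=True))
clamp = build_activation_layer(dict(type='Clip', min=0., max=6.))  # 'Clip' is the alias registered for Clamp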
custom_mmpkg/custom_mmcv/cnn/bricks/context_block.py
ADDED
@@ -0,0 +1,125 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch import nn

from ..utils import constant_init, kaiming_init
from .registry import PLUGIN_LAYERS


def last_zero_init(m):
    if isinstance(m, nn.Sequential):
        constant_init(m[-1], val=0)
    else:
        constant_init(m, val=0)


@PLUGIN_LAYERS.register_module()
class ContextBlock(nn.Module):
    """ContextBlock module in GCNet.

    See 'GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond'
    (https://arxiv.org/abs/1904.11492) for details.

    Args:
        in_channels (int): Channels of the input feature map.
        ratio (float): Ratio of channels of transform bottleneck
        pooling_type (str): Pooling method for context modeling.
            Options are 'att' and 'avg', stand for attention pooling and
            average pooling respectively. Default: 'att'.
        fusion_types (Sequence[str]): Fusion method for feature fusion,
            Options are 'channel_add', 'channel_mul', stand for channelwise
            addition and multiplication respectively. Default: ('channel_add',)
    """

    _abbr_ = 'context_block'

    def __init__(self,
                 in_channels,
                 ratio,
                 pooling_type='att',
                 fusion_types=('channel_add', )):
        super(ContextBlock, self).__init__()
        assert pooling_type in ['avg', 'att']
        assert isinstance(fusion_types, (list, tuple))
        valid_fusion_types = ['channel_add', 'channel_mul']
        assert all([f in valid_fusion_types for f in fusion_types])
        assert len(fusion_types) > 0, 'at least one fusion should be used'
        self.in_channels = in_channels
        self.ratio = ratio
        self.planes = int(in_channels * ratio)
        self.pooling_type = pooling_type
        self.fusion_types = fusion_types
        if pooling_type == 'att':
            self.conv_mask = nn.Conv2d(in_channels, 1, kernel_size=1)
            self.softmax = nn.Softmax(dim=2)
        else:
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
        if 'channel_add' in fusion_types:
            self.channel_add_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_add_conv = None
        if 'channel_mul' in fusion_types:
            self.channel_mul_conv = nn.Sequential(
                nn.Conv2d(self.in_channels, self.planes, kernel_size=1),
                nn.LayerNorm([self.planes, 1, 1]),
                nn.ReLU(inplace=True),  # yapf: disable
                nn.Conv2d(self.planes, self.in_channels, kernel_size=1))
        else:
            self.channel_mul_conv = None
        self.reset_parameters()

    def reset_parameters(self):
        if self.pooling_type == 'att':
            kaiming_init(self.conv_mask, mode='fan_in')
            self.conv_mask.inited = True

        if self.channel_add_conv is not None:
            last_zero_init(self.channel_add_conv)
        if self.channel_mul_conv is not None:
            last_zero_init(self.channel_mul_conv)

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()
        if self.pooling_type == 'att':
            input_x = x
            # [N, C, H * W]
            input_x = input_x.view(batch, channel, height * width)
            # [N, 1, C, H * W]
            input_x = input_x.unsqueeze(1)
            # [N, 1, H, W]
            context_mask = self.conv_mask(x)
            # [N, 1, H * W]
            context_mask = context_mask.view(batch, 1, height * width)
            # [N, 1, H * W]
            context_mask = self.softmax(context_mask)
            # [N, 1, H * W, 1]
            context_mask = context_mask.unsqueeze(-1)
            # [N, 1, C, 1]
            context = torch.matmul(input_x, context_mask)
            # [N, C, 1, 1]
            context = context.view(batch, channel, 1, 1)
        else:
            # [N, C, 1, 1]
            context = self.avg_pool(x)

        return context

    def forward(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        if self.channel_mul_conv is not None:
            # [N, C, 1, 1]
            channel_mul_term = torch.sigmoid(self.channel_mul_conv(context))
            out = out * channel_mul_term
        if self.channel_add_conv is not None:
            # [N, C, 1, 1]
            channel_add_term = self.channel_add_conv(context)
            out = out + channel_add_term

        return out
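ContextBlock is a drop-in residual-style plugin: it pools a global context vector and adds (or multiplies) it back onto every spatial position, so the output keeps the input shape. A quick shape check with an arbitrary channel count and ratio:

import torch
from custom_mmpkg.custom_mmcv.cnn import ContextBlock

gc = ContextBlock(in_channels=64, ratio=1. / 4)   # bottleneck of 16 channels
x = torch.randn(2, 64, 20, 20)
print(gc(x).shape)                                # torch.Size([2, 64, 20, 20])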
custom_mmpkg/custom_mmcv/cnn/bricks/conv.py
ADDED
@@ -0,0 +1,44 @@
# Copyright (c) OpenMMLab. All rights reserved.
from torch import nn

from .registry import CONV_LAYERS

CONV_LAYERS.register_module('Conv1d', module=nn.Conv1d)
CONV_LAYERS.register_module('Conv2d', module=nn.Conv2d)
CONV_LAYERS.register_module('Conv3d', module=nn.Conv3d)
CONV_LAYERS.register_module('Conv', module=nn.Conv2d)


def build_conv_layer(cfg, *args, **kwargs):
    """Build convolution layer.

    Args:
        cfg (None or dict): The conv layer config, which should contain:
            - type (str): Layer type.
            - layer args: Args needed to instantiate an conv layer.
        args (argument list): Arguments passed to the `__init__`
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the `__init__`
            method of the corresponding conv layer.

    Returns:
        nn.Module: Created conv layer.
    """
    if cfg is None:
        cfg_ = dict(type='Conv2d')
    else:
        if not isinstance(cfg, dict):
            raise TypeError('cfg must be a dict')
        if 'type' not in cfg:
            raise KeyError('the cfg dict must contain the key "type"')
        cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in CONV_LAYERS:
        raise KeyError(f'Unrecognized norm type {layer_type}')
    else:
        conv_layer = CONV_LAYERS.get(layer_type)

    layer = conv_layer(*args, **kwargs, **cfg_)

    return layer
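build_conv_layer mirrors the other builders: the config dict selects the registered layer type and the remaining positional and keyword arguments go straight to that layer's constructor. A small sketch:

from custom_mmpkg.custom_mmcv.cnn import build_conv_layer

conv = build_conv_layer(dict(type='Conv2d'), 3, 16, kernel_size=3, padding=1)
default = build_conv_layer(None, 3, 16, kernel_size=1)  # cfg=None falls back to Conv2d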
custom_mmpkg/custom_mmcv/cnn/bricks/conv2d_adaptive_padding.py
ADDED
@@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math

from torch import nn
from torch.nn import functional as F

from .registry import CONV_LAYERS


@CONV_LAYERS.register_module()
class Conv2dAdaptivePadding(nn.Conv2d):
    """Implementation of 2D convolution in tensorflow with `padding` as "same",
    which applies padding to input (if needed) so that input image gets fully
    covered by filter and stride you specified. For stride 1, this will ensure
    that output image size is same as input. For stride of 2, output dimensions
    will be half, for example.

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the convolving kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If ``True``, adds a learnable bias to the
            output. Default: ``True``
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(in_channels, out_channels, kernel_size, stride, 0,
                         dilation, groups, bias)

    def forward(self, x):
        img_h, img_w = x.size()[-2:]
        kernel_h, kernel_w = self.weight.size()[-2:]
        stride_h, stride_w = self.stride
        output_h = math.ceil(img_h / stride_h)
        output_w = math.ceil(img_w / stride_w)
        pad_h = (
            max((output_h - 1) * self.stride[0] +
                (kernel_h - 1) * self.dilation[0] + 1 - img_h, 0))
        pad_w = (
            max((output_w - 1) * self.stride[1] +
                (kernel_w - 1) * self.dilation[1] + 1 - img_w, 0))
        if pad_h > 0 or pad_w > 0:
            x = F.pad(x, [
                pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2
            ])
        return F.conv2d(x, self.weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)
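A small check of the "same"-style padding described in the docstring, using arbitrary odd input sizes so the ceil behaviour is visible:

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks import Conv2dAdaptivePadding

conv = Conv2dAdaptivePadding(3, 8, kernel_size=3, stride=2)
x = torch.randn(1, 3, 65, 97)
print(conv(x).shape)  # torch.Size([1, 8, 33, 49]) -> ceil(input / stride) per side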
custom_mmpkg/custom_mmcv/cnn/bricks/conv_module.py
ADDED
@@ -0,0 +1,206 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import torch.nn as nn

from custom_mmpkg.custom_mmcv.utils import _BatchNorm, _InstanceNorm
from ..utils import constant_init, kaiming_init
from .activation import build_activation_layer
from .conv import build_conv_layer
from .norm import build_norm_layer
from .padding import build_padding_layer
from .registry import PLUGIN_LAYERS


@PLUGIN_LAYERS.register_module()
class ConvModule(nn.Module):
    """A conv block that bundles conv/norm/activation layers.

    This block simplifies the usage of convolution layers, which are commonly
    used with a norm layer (e.g., BatchNorm) and activation layer (e.g., ReLU).
    It is based upon three build methods: `build_conv_layer()`,
    `build_norm_layer()` and `build_activation_layer()`.

    Besides, we add some additional features in this module.
    1. Automatically set `bias` of the conv layer.
    2. Spectral norm is supported.
    3. More padding modes are supported. Before PyTorch 1.5, nn.Conv2d only
    supports zero and circular padding, and we add "reflect" padding mode.

    Args:
        in_channels (int): Number of channels in the input feature map.
            Same as that in ``nn._ConvNd``.
        out_channels (int): Number of channels produced by the convolution.
            Same as that in ``nn._ConvNd``.
        kernel_size (int | tuple[int]): Size of the convolving kernel.
            Same as that in ``nn._ConvNd``.
        stride (int | tuple[int]): Stride of the convolution.
            Same as that in ``nn._ConvNd``.
        padding (int | tuple[int]): Zero-padding added to both sides of
            the input. Same as that in ``nn._ConvNd``.
        dilation (int | tuple[int]): Spacing between kernel elements.
            Same as that in ``nn._ConvNd``.
        groups (int): Number of blocked connections from input channels to
            output channels. Same as that in ``nn._ConvNd``.
        bias (bool | str): If specified as `auto`, it will be decided by the
            norm_cfg. Bias will be set as True if `norm_cfg` is None, otherwise
            False. Default: "auto".
        conv_cfg (dict): Config dict for convolution layer. Default: None,
            which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        inplace (bool): Whether to use inplace mode for activation.
            Default: True.
        with_spectral_norm (bool): Whether use spectral norm in conv module.
            Default: False.
        padding_mode (str): If the `padding_mode` has not been supported by
            current `Conv2d` in PyTorch, we will use our own padding layer
            instead. Currently, we support ['zeros', 'circular'] with official
            implementation and ['reflect'] with our own implementation.
            Default: 'zeros'.
        order (tuple[str]): The order of conv/norm/activation layers. It is a
            sequence of "conv", "norm" and "act". Common examples are
            ("conv", "norm", "act") and ("act", "conv", "norm").
            Default: ('conv', 'norm', 'act').
    """

    _abbr_ = 'conv_block'

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias='auto',
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=dict(type='ReLU'),
                 inplace=True,
                 with_spectral_norm=False,
                 padding_mode='zeros',
                 order=('conv', 'norm', 'act')):
        super(ConvModule, self).__init__()
        assert conv_cfg is None or isinstance(conv_cfg, dict)
        assert norm_cfg is None or isinstance(norm_cfg, dict)
        assert act_cfg is None or isinstance(act_cfg, dict)
        official_padding_mode = ['zeros', 'circular']
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.inplace = inplace
        self.with_spectral_norm = with_spectral_norm
        self.with_explicit_padding = padding_mode not in official_padding_mode
        self.order = order
        assert isinstance(self.order, tuple) and len(self.order) == 3
        assert set(order) == set(['conv', 'norm', 'act'])

        self.with_norm = norm_cfg is not None
        self.with_activation = act_cfg is not None
        # if the conv layer is before a norm layer, bias is unnecessary.
        if bias == 'auto':
            bias = not self.with_norm
        self.with_bias = bias

        if self.with_explicit_padding:
            pad_cfg = dict(type=padding_mode)
            self.padding_layer = build_padding_layer(pad_cfg, padding)

        # reset padding to 0 for conv module
        conv_padding = 0 if self.with_explicit_padding else padding
        # build convolution layer
        self.conv = build_conv_layer(
            conv_cfg,
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=conv_padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        # export the attributes of self.conv to a higher level for convenience
        self.in_channels = self.conv.in_channels
        self.out_channels = self.conv.out_channels
        self.kernel_size = self.conv.kernel_size
        self.stride = self.conv.stride
        self.padding = padding
        self.dilation = self.conv.dilation
        self.transposed = self.conv.transposed
        self.output_padding = self.conv.output_padding
        self.groups = self.conv.groups

        if self.with_spectral_norm:
            self.conv = nn.utils.spectral_norm(self.conv)

        # build normalization layers
        if self.with_norm:
            # norm layer is after conv layer
            if order.index('norm') > order.index('conv'):
                norm_channels = out_channels
            else:
                norm_channels = in_channels
            self.norm_name, norm = build_norm_layer(norm_cfg, norm_channels)
            self.add_module(self.norm_name, norm)
            if self.with_bias:
                if isinstance(norm, (_BatchNorm, _InstanceNorm)):
                    warnings.warn(
                        'Unnecessary conv bias before batch/instance norm')
        else:
            self.norm_name = None

        # build activation layer
        if self.with_activation:
            act_cfg_ = act_cfg.copy()
            # nn.Tanh has no 'inplace' argument
            if act_cfg_['type'] not in [
                    'Tanh', 'PReLU', 'Sigmoid', 'HSigmoid', 'Swish'
            ]:
                act_cfg_.setdefault('inplace', inplace)
            self.activate = build_activation_layer(act_cfg_)

        # Use msra init by default
        self.init_weights()

    @property
    def norm(self):
        if self.norm_name:
            return getattr(self, self.norm_name)
        else:
            return None

    def init_weights(self):
        # 1. It is mainly for customized conv layers with their own
        #    initialization manners by calling their own ``init_weights()``,
        #    and we do not want ConvModule to override the initialization.
        # 2. For customized conv layers without their own initialization
        #    manners (that is, they don't have their own ``init_weights()``)
        #    and PyTorch's conv layers, they will be initialized by
        #    this method with default ``kaiming_init``.
        # Note: For PyTorch's conv layers, they will be overwritten by our
        #    initialization implementation using default ``kaiming_init``.
        if not hasattr(self.conv, 'init_weights'):
            if self.with_activation and self.act_cfg['type'] == 'LeakyReLU':
                nonlinearity = 'leaky_relu'
                a = self.act_cfg.get('negative_slope', 0.01)
            else:
                nonlinearity = 'relu'
                a = 0
            kaiming_init(self.conv, a=a, nonlinearity=nonlinearity)
        if self.with_norm:
            constant_init(self.norm, 1, bias=0)

    def forward(self, x, activate=True, norm=True):
        for layer in self.order:
            if layer == 'conv':
                if self.with_explicit_padding:
                    x = self.padding_layer(x)
                x = self.conv(x)
            elif layer == 'norm' and norm and self.with_norm:
                x = self.norm(x)
            elif layer == 'act' and activate and self.with_activation:
                x = self.activate(x)
        return x
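ConvModule assembles conv/norm/activation from config dicts, so a conv-BN-ReLU block is a single call. A minimal sketch, assuming the package's norm registry exposes the usual 'BN' type:

import torch
from custom_mmpkg.custom_mmcv.cnn import ConvModule

block = ConvModule(
    in_channels=3,
    out_channels=16,
    kernel_size=3,
    padding=1,
    norm_cfg=dict(type='BN'),   # with a norm layer, bias='auto' drops the conv bias
    act_cfg=dict(type='ReLU'),
)
y = block(torch.randn(2, 3, 32, 32))  # shape (2, 16, 32, 32)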
custom_mmpkg/custom_mmcv/cnn/bricks/conv_ws.py
ADDED
@@ -0,0 +1,148 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from .registry import CONV_LAYERS


def conv_ws_2d(input,
               weight,
               bias=None,
               stride=1,
               padding=0,
               dilation=1,
               groups=1,
               eps=1e-5):
    c_in = weight.size(0)
    weight_flat = weight.view(c_in, -1)
    mean = weight_flat.mean(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    std = weight_flat.std(dim=1, keepdim=True).view(c_in, 1, 1, 1)
    weight = (weight - mean) / (std + eps)
    return F.conv2d(input, weight, bias, stride, padding, dilation, groups)


@CONV_LAYERS.register_module('ConvWS')
class ConvWS2d(nn.Conv2d):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True,
                 eps=1e-5):
        super(ConvWS2d, self).__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.eps = eps

    def forward(self, x):
        return conv_ws_2d(x, self.weight, self.bias, self.stride, self.padding,
                          self.dilation, self.groups, self.eps)


@CONV_LAYERS.register_module(name='ConvAWS')
class ConvAWS2d(nn.Conv2d):
    """AWS (Adaptive Weight Standardization)

    This is a variant of Weight Standardization
    (https://arxiv.org/pdf/1903.10520.pdf)
    It is used in DetectoRS to avoid NaN
    (https://arxiv.org/pdf/2006.02334.pdf)

    Args:
        in_channels (int): Number of channels in the input image
        out_channels (int): Number of channels produced by the convolution
        kernel_size (int or tuple): Size of the conv kernel
        stride (int or tuple, optional): Stride of the convolution. Default: 1
        padding (int or tuple, optional): Zero-padding added to both sides of
            the input. Default: 0
        dilation (int or tuple, optional): Spacing between kernel elements.
            Default: 1
        groups (int, optional): Number of blocked connections from input
            channels to output channels. Default: 1
        bias (bool, optional): If set True, adds a learnable bias to the
            output. Default: True
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias)
        self.register_buffer('weight_gamma',
                             torch.ones(self.out_channels, 1, 1, 1))
        self.register_buffer('weight_beta',
                             torch.zeros(self.out_channels, 1, 1, 1))

    def _get_weight(self, weight):
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        weight = (weight - mean) / std
        weight = self.weight_gamma * weight + self.weight_beta
        return weight

    def forward(self, x):
        weight = self._get_weight(self.weight)
        return F.conv2d(x, weight, self.bias, self.stride, self.padding,
                        self.dilation, self.groups)

    def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
                              missing_keys, unexpected_keys, error_msgs):
        """Override default load function.

        AWS overrides the function _load_from_state_dict to recover
        weight_gamma and weight_beta if they are missing. If weight_gamma and
        weight_beta are found in the checkpoint, this function will return
        after super()._load_from_state_dict. Otherwise, it will compute the
        mean and std of the pretrained weights and store them in weight_beta
        and weight_gamma.
        """

        self.weight_gamma.data.fill_(-1)
        local_missing_keys = []
        super()._load_from_state_dict(state_dict, prefix, local_metadata,
                                      strict, local_missing_keys,
                                      unexpected_keys, error_msgs)
        if self.weight_gamma.data.mean() > 0:
            for k in local_missing_keys:
                missing_keys.append(k)
            return
        weight = self.weight.data
        weight_flat = weight.view(weight.size(0), -1)
        mean = weight_flat.mean(dim=1).view(-1, 1, 1, 1)
        std = torch.sqrt(weight_flat.var(dim=1) + 1e-5).view(-1, 1, 1, 1)
        self.weight_beta.data.copy_(mean)
        self.weight_gamma.data.copy_(std)
        missing_gamma_beta = [
            k for k in local_missing_keys
            if k.endswith('weight_gamma') or k.endswith('weight_beta')
        ]
        for k in missing_gamma_beta:
            local_missing_keys.remove(k)
        for k in local_missing_keys:
            missing_keys.append(k)
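Both weight-standardized variants behave like a plain nn.Conv2d at call time; only the effective weights differ. A short shape check:

import torch
from custom_mmpkg.custom_mmcv.cnn import ConvWS2d, ConvAWS2d

ws = ConvWS2d(3, 8, kernel_size=3, padding=1)
aws = ConvAWS2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 16, 16)
print(ws(x).shape, aws(x).shape)  # both torch.Size([1, 8, 16, 16])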
custom_mmpkg/custom_mmcv/cnn/bricks/depthwise_separable_conv_module.py
ADDED
@@ -0,0 +1,96 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .conv_module import ConvModule
|
5 |
+
|
6 |
+
|
7 |
+
class DepthwiseSeparableConvModule(nn.Module):
|
8 |
+
"""Depthwise separable convolution module.
|
9 |
+
|
10 |
+
See https://arxiv.org/pdf/1704.04861.pdf for details.
|
11 |
+
|
12 |
+
This module can replace a ConvModule with the conv block replaced by two
|
13 |
+
conv block: depthwise conv block and pointwise conv block. The depthwise
|
14 |
+
conv block contains depthwise-conv/norm/activation layers. The pointwise
|
15 |
+
conv block contains pointwise-conv/norm/activation layers. It should be
|
16 |
+
noted that there will be norm/activation layer in the depthwise conv block
|
17 |
+
if `norm_cfg` and `act_cfg` are specified.
|
18 |
+
|
19 |
+
Args:
|
20 |
+
in_channels (int): Number of channels in the input feature map.
|
21 |
+
Same as that in ``nn._ConvNd``.
|
22 |
+
out_channels (int): Number of channels produced by the convolution.
|
23 |
+
Same as that in ``nn._ConvNd``.
|
24 |
+
kernel_size (int | tuple[int]): Size of the convolving kernel.
|
25 |
+
Same as that in ``nn._ConvNd``.
|
26 |
+
stride (int | tuple[int]): Stride of the convolution.
|
27 |
+
Same as that in ``nn._ConvNd``. Default: 1.
|
28 |
+
padding (int | tuple[int]): Zero-padding added to both sides of
|
29 |
+
the input. Same as that in ``nn._ConvNd``. Default: 0.
|
30 |
+
dilation (int | tuple[int]): Spacing between kernel elements.
|
31 |
+
Same as that in ``nn._ConvNd``. Default: 1.
|
32 |
+
norm_cfg (dict): Default norm config for both depthwise ConvModule and
|
33 |
+
pointwise ConvModule. Default: None.
|
34 |
+
act_cfg (dict): Default activation config for both depthwise ConvModule
|
35 |
+
and pointwise ConvModule. Default: dict(type='ReLU').
|
36 |
+
dw_norm_cfg (dict): Norm config of depthwise ConvModule. If it is
|
37 |
+
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
38 |
+
dw_act_cfg (dict): Activation config of depthwise ConvModule. If it is
|
39 |
+
'default', it will be the same as `act_cfg`. Default: 'default'.
|
40 |
+
pw_norm_cfg (dict): Norm config of pointwise ConvModule. If it is
|
41 |
+
'default', it will be the same as `norm_cfg`. Default: 'default'.
|
42 |
+
pw_act_cfg (dict): Activation config of pointwise ConvModule. If it is
|
43 |
+
'default', it will be the same as `act_cfg`. Default: 'default'.
|
44 |
+
kwargs (optional): Other shared arguments for depthwise and pointwise
|
45 |
+
ConvModule. See ConvModule for ref.
|
46 |
+
"""
|
47 |
+
|
48 |
+
def __init__(self,
|
49 |
+
in_channels,
|
50 |
+
out_channels,
|
51 |
+
kernel_size,
|
52 |
+
stride=1,
|
53 |
+
padding=0,
|
54 |
+
dilation=1,
|
55 |
+
norm_cfg=None,
|
56 |
+
act_cfg=dict(type='ReLU'),
|
57 |
+
dw_norm_cfg='default',
|
58 |
+
dw_act_cfg='default',
|
59 |
+
pw_norm_cfg='default',
|
60 |
+
pw_act_cfg='default',
|
61 |
+
**kwargs):
|
62 |
+
super(DepthwiseSeparableConvModule, self).__init__()
|
63 |
+
assert 'groups' not in kwargs, 'groups should not be specified'
|
64 |
+
|
65 |
+
# if norm/activation config of depthwise/pointwise ConvModule is not
|
66 |
+
# specified, use default config.
|
67 |
+
dw_norm_cfg = dw_norm_cfg if dw_norm_cfg != 'default' else norm_cfg
|
68 |
+
dw_act_cfg = dw_act_cfg if dw_act_cfg != 'default' else act_cfg
|
69 |
+
pw_norm_cfg = pw_norm_cfg if pw_norm_cfg != 'default' else norm_cfg
|
70 |
+
pw_act_cfg = pw_act_cfg if pw_act_cfg != 'default' else act_cfg
|
71 |
+
|
72 |
+
# depthwise convolution
|
73 |
+
self.depthwise_conv = ConvModule(
|
74 |
+
in_channels,
|
75 |
+
in_channels,
|
76 |
+
kernel_size,
|
77 |
+
stride=stride,
|
78 |
+
padding=padding,
|
79 |
+
dilation=dilation,
|
80 |
+
groups=in_channels,
|
81 |
+
norm_cfg=dw_norm_cfg,
|
82 |
+
act_cfg=dw_act_cfg,
|
83 |
+
**kwargs)
|
84 |
+
|
85 |
+
self.pointwise_conv = ConvModule(
|
86 |
+
in_channels,
|
87 |
+
out_channels,
|
88 |
+
1,
|
89 |
+
norm_cfg=pw_norm_cfg,
|
90 |
+
act_cfg=pw_act_cfg,
|
91 |
+
**kwargs)
|
92 |
+
|
93 |
+
def forward(self, x):
|
94 |
+
x = self.depthwise_conv(x)
|
95 |
+
x = self.pointwise_conv(x)
|
96 |
+
return x
|
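A minimal usage sketch of DepthwiseSeparableConvModule (the import path is assumed from the file location above); a 3x3 depthwise conv followed by a 1x1 pointwise conv stands in for a single dense 3x3 conv:

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.depthwise_separable_conv_module import DepthwiseSeparableConvModule

# depthwise 3x3 (groups=in_channels) + pointwise 1x1, each with BN + ReLU
conv = DepthwiseSeparableConvModule(32, 64, 3, padding=1, norm_cfg=dict(type='BN'))
out = conv(torch.randn(1, 32, 56, 56))  # -> shape (1, 64, 56, 56)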
custom_mmpkg/custom_mmcv/cnn/bricks/drop.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from custom_mmpkg.custom_mmcv import build_from_cfg
|
6 |
+
from .registry import DROPOUT_LAYERS
|
7 |
+
|
8 |
+
|
9 |
+
def drop_path(x, drop_prob=0., training=False):
|
10 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
|
11 |
+
residual blocks).
|
12 |
+
|
13 |
+
We follow the implementation
|
14 |
+
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
|
15 |
+
"""
|
16 |
+
if drop_prob == 0. or not training:
|
17 |
+
return x
|
18 |
+
keep_prob = 1 - drop_prob
|
19 |
+
# handle tensors with different dimensions, not just 4D tensors.
|
20 |
+
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
|
21 |
+
random_tensor = keep_prob + torch.rand(
|
22 |
+
shape, dtype=x.dtype, device=x.device)
|
23 |
+
output = x.div(keep_prob) * random_tensor.floor()
|
24 |
+
return output
|
25 |
+
|
26 |
+
|
27 |
+
@DROPOUT_LAYERS.register_module()
|
28 |
+
class DropPath(nn.Module):
|
29 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of
|
30 |
+
residual blocks).
|
31 |
+
|
32 |
+
We follow the implementation
|
33 |
+
https://github.com/rwightman/pytorch-image-models/blob/a2727c1bf78ba0d7b5727f5f95e37fb7f8866b1f/timm/models/layers/drop.py # noqa: E501
|
34 |
+
|
35 |
+
Args:
|
36 |
+
drop_prob (float): Probability of the path to be zeroed. Default: 0.1
|
37 |
+
"""
|
38 |
+
|
39 |
+
def __init__(self, drop_prob=0.1):
|
40 |
+
super(DropPath, self).__init__()
|
41 |
+
self.drop_prob = drop_prob
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
return drop_path(x, self.drop_prob, self.training)
|
45 |
+
|
46 |
+
|
47 |
+
@DROPOUT_LAYERS.register_module()
|
48 |
+
class Dropout(nn.Dropout):
|
49 |
+
"""A wrapper for ``torch.nn.Dropout``, We rename the ``p`` of
|
50 |
+
``torch.nn.Dropout`` to ``drop_prob`` so as to be consistent with
|
51 |
+
``DropPath``
|
52 |
+
|
53 |
+
Args:
|
54 |
+
drop_prob (float): Probability of the elements to be
|
55 |
+
zeroed. Default: 0.5.
|
56 |
+
inplace (bool): Do the operation inplace or not. Default: False.
|
57 |
+
"""
|
58 |
+
|
59 |
+
def __init__(self, drop_prob=0.5, inplace=False):
|
60 |
+
super().__init__(p=drop_prob, inplace=inplace)
|
61 |
+
|
62 |
+
|
63 |
+
def build_dropout(cfg, default_args=None):
|
64 |
+
"""Builder for drop out layers."""
|
65 |
+
return build_from_cfg(cfg, DROPOUT_LAYERS, default_args)
|
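A minimal sketch of building DropPath through the DROPOUT_LAYERS registry (import path assumed from the file location above):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.drop import build_dropout

layer = build_dropout(dict(type='DropPath', drop_prob=0.2))
layer.train()                          # stochastic depth only acts in training mode
out = layer(torch.randn(4, 197, 768))  # whole samples are zeroed; the rest are scaled by 1/keep_prob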
custom_mmpkg/custom_mmcv/cnn/bricks/generalized_attention.py
ADDED
@@ -0,0 +1,412 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import math
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
import torch.nn.functional as F
|
8 |
+
|
9 |
+
from ..utils import kaiming_init
|
10 |
+
from .registry import PLUGIN_LAYERS
|
11 |
+
|
12 |
+
|
13 |
+
@PLUGIN_LAYERS.register_module()
|
14 |
+
class GeneralizedAttention(nn.Module):
|
15 |
+
"""GeneralizedAttention module.
|
16 |
+
|
17 |
+
See 'An Empirical Study of Spatial Attention Mechanisms in Deep Networks'
|
18 |
+
(https://arxiv.org/abs/1904.05873) for details.
|
19 |
+
|
20 |
+
Args:
|
21 |
+
in_channels (int): Channels of the input feature map.
|
22 |
+
spatial_range (int): The spatial range. -1 indicates no spatial range
|
23 |
+
constraint. Default: -1.
|
24 |
+
num_heads (int): The head number of empirical_attention module.
|
25 |
+
Default: 9.
|
26 |
+
position_embedding_dim (int): The position embedding dimension.
|
27 |
+
Default: -1.
|
28 |
+
position_magnitude (int): A multiplier acting on coord difference.
|
29 |
+
Default: 1.
|
30 |
+
kv_stride (int): The feature stride acting on key/value feature map.
|
31 |
+
Default: 2.
|
32 |
+
q_stride (int): The feature stride acting on query feature map.
|
33 |
+
Default: 1.
|
34 |
+
attention_type (str): A binary indicator string for indicating which
|
35 |
+
items in generalized empirical_attention module are used.
|
36 |
+
Default: '1111'.
|
37 |
+
|
38 |
+
- '1000' indicates 'query and key content' (appr - appr) item,
|
39 |
+
- '0100' indicates 'query content and relative position'
|
40 |
+
(appr - position) item,
|
41 |
+
- '0010' indicates 'key content only' (bias - appr) item,
|
42 |
+
- '0001' indicates 'relative position only' (bias - position) item.
|
43 |
+
"""
|
44 |
+
|
45 |
+
_abbr_ = 'gen_attention_block'
|
46 |
+
|
47 |
+
def __init__(self,
|
48 |
+
in_channels,
|
49 |
+
spatial_range=-1,
|
50 |
+
num_heads=9,
|
51 |
+
position_embedding_dim=-1,
|
52 |
+
position_magnitude=1,
|
53 |
+
kv_stride=2,
|
54 |
+
q_stride=1,
|
55 |
+
attention_type='1111'):
|
56 |
+
|
57 |
+
super(GeneralizedAttention, self).__init__()
|
58 |
+
|
59 |
+
# hard range means local range for non-local operation
|
60 |
+
self.position_embedding_dim = (
|
61 |
+
position_embedding_dim
|
62 |
+
if position_embedding_dim > 0 else in_channels)
|
63 |
+
|
64 |
+
self.position_magnitude = position_magnitude
|
65 |
+
self.num_heads = num_heads
|
66 |
+
self.in_channels = in_channels
|
67 |
+
self.spatial_range = spatial_range
|
68 |
+
self.kv_stride = kv_stride
|
69 |
+
self.q_stride = q_stride
|
70 |
+
self.attention_type = [bool(int(_)) for _ in attention_type]
|
71 |
+
self.qk_embed_dim = in_channels // num_heads
|
72 |
+
out_c = self.qk_embed_dim * num_heads
|
73 |
+
|
74 |
+
if self.attention_type[0] or self.attention_type[1]:
|
75 |
+
self.query_conv = nn.Conv2d(
|
76 |
+
in_channels=in_channels,
|
77 |
+
out_channels=out_c,
|
78 |
+
kernel_size=1,
|
79 |
+
bias=False)
|
80 |
+
self.query_conv.kaiming_init = True
|
81 |
+
|
82 |
+
if self.attention_type[0] or self.attention_type[2]:
|
83 |
+
self.key_conv = nn.Conv2d(
|
84 |
+
in_channels=in_channels,
|
85 |
+
out_channels=out_c,
|
86 |
+
kernel_size=1,
|
87 |
+
bias=False)
|
88 |
+
self.key_conv.kaiming_init = True
|
89 |
+
|
90 |
+
self.v_dim = in_channels // num_heads
|
91 |
+
self.value_conv = nn.Conv2d(
|
92 |
+
in_channels=in_channels,
|
93 |
+
out_channels=self.v_dim * num_heads,
|
94 |
+
kernel_size=1,
|
95 |
+
bias=False)
|
96 |
+
self.value_conv.kaiming_init = True
|
97 |
+
|
98 |
+
if self.attention_type[1] or self.attention_type[3]:
|
99 |
+
self.appr_geom_fc_x = nn.Linear(
|
100 |
+
self.position_embedding_dim // 2, out_c, bias=False)
|
101 |
+
self.appr_geom_fc_x.kaiming_init = True
|
102 |
+
|
103 |
+
self.appr_geom_fc_y = nn.Linear(
|
104 |
+
self.position_embedding_dim // 2, out_c, bias=False)
|
105 |
+
self.appr_geom_fc_y.kaiming_init = True
|
106 |
+
|
107 |
+
if self.attention_type[2]:
|
108 |
+
stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
|
109 |
+
appr_bias_value = -2 * stdv * torch.rand(out_c) + stdv
|
110 |
+
self.appr_bias = nn.Parameter(appr_bias_value)
|
111 |
+
|
112 |
+
if self.attention_type[3]:
|
113 |
+
stdv = 1.0 / math.sqrt(self.qk_embed_dim * 2)
|
114 |
+
geom_bias_value = -2 * stdv * torch.rand(out_c) + stdv
|
115 |
+
self.geom_bias = nn.Parameter(geom_bias_value)
|
116 |
+
|
117 |
+
self.proj_conv = nn.Conv2d(
|
118 |
+
in_channels=self.v_dim * num_heads,
|
119 |
+
out_channels=in_channels,
|
120 |
+
kernel_size=1,
|
121 |
+
bias=True)
|
122 |
+
self.proj_conv.kaiming_init = True
|
123 |
+
self.gamma = nn.Parameter(torch.zeros(1))
|
124 |
+
|
125 |
+
if self.spatial_range >= 0:
|
126 |
+
# only works when non local is after 3*3 conv
|
127 |
+
if in_channels == 256:
|
128 |
+
max_len = 84
|
129 |
+
elif in_channels == 512:
|
130 |
+
max_len = 42
|
131 |
+
|
132 |
+
max_len_kv = int((max_len - 1.0) / self.kv_stride + 1)
|
133 |
+
local_constraint_map = np.ones(
|
134 |
+
(max_len, max_len, max_len_kv, max_len_kv), dtype=int)  # np.int was removed in recent NumPy
|
135 |
+
for iy in range(max_len):
|
136 |
+
for ix in range(max_len):
|
137 |
+
local_constraint_map[
|
138 |
+
iy, ix,
|
139 |
+
max((iy - self.spatial_range) //
|
140 |
+
self.kv_stride, 0):min((iy + self.spatial_range +
|
141 |
+
1) // self.kv_stride +
|
142 |
+
1, max_len),
|
143 |
+
max((ix - self.spatial_range) //
|
144 |
+
self.kv_stride, 0):min((ix + self.spatial_range +
|
145 |
+
1) // self.kv_stride +
|
146 |
+
1, max_len)] = 0
|
147 |
+
|
148 |
+
self.local_constraint_map = nn.Parameter(
|
149 |
+
torch.from_numpy(local_constraint_map).byte(),
|
150 |
+
requires_grad=False)
|
151 |
+
|
152 |
+
if self.q_stride > 1:
|
153 |
+
self.q_downsample = nn.AvgPool2d(
|
154 |
+
kernel_size=1, stride=self.q_stride)
|
155 |
+
else:
|
156 |
+
self.q_downsample = None
|
157 |
+
|
158 |
+
if self.kv_stride > 1:
|
159 |
+
self.kv_downsample = nn.AvgPool2d(
|
160 |
+
kernel_size=1, stride=self.kv_stride)
|
161 |
+
else:
|
162 |
+
self.kv_downsample = None
|
163 |
+
|
164 |
+
self.init_weights()
|
165 |
+
|
166 |
+
def get_position_embedding(self,
|
167 |
+
h,
|
168 |
+
w,
|
169 |
+
h_kv,
|
170 |
+
w_kv,
|
171 |
+
q_stride,
|
172 |
+
kv_stride,
|
173 |
+
device,
|
174 |
+
dtype,
|
175 |
+
feat_dim,
|
176 |
+
wave_length=1000):
|
177 |
+
# the default type of Tensor is float32, leading to type mismatch
|
178 |
+
# in fp16 mode. Cast it to support fp16 mode.
|
179 |
+
h_idxs = torch.linspace(0, h - 1, h).to(device=device, dtype=dtype)
|
180 |
+
h_idxs = h_idxs.view((h, 1)) * q_stride
|
181 |
+
|
182 |
+
w_idxs = torch.linspace(0, w - 1, w).to(device=device, dtype=dtype)
|
183 |
+
w_idxs = w_idxs.view((w, 1)) * q_stride
|
184 |
+
|
185 |
+
h_kv_idxs = torch.linspace(0, h_kv - 1, h_kv).to(
|
186 |
+
device=device, dtype=dtype)
|
187 |
+
h_kv_idxs = h_kv_idxs.view((h_kv, 1)) * kv_stride
|
188 |
+
|
189 |
+
w_kv_idxs = torch.linspace(0, w_kv - 1, w_kv).to(
|
190 |
+
device=device, dtype=dtype)
|
191 |
+
w_kv_idxs = w_kv_idxs.view((w_kv, 1)) * kv_stride
|
192 |
+
|
193 |
+
# (h, h_kv, 1)
|
194 |
+
h_diff = h_idxs.unsqueeze(1) - h_kv_idxs.unsqueeze(0)
|
195 |
+
h_diff *= self.position_magnitude
|
196 |
+
|
197 |
+
# (w, w_kv, 1)
|
198 |
+
w_diff = w_idxs.unsqueeze(1) - w_kv_idxs.unsqueeze(0)
|
199 |
+
w_diff *= self.position_magnitude
|
200 |
+
|
201 |
+
feat_range = torch.arange(0, feat_dim / 4).to(
|
202 |
+
device=device, dtype=dtype)
|
203 |
+
|
204 |
+
dim_mat = torch.Tensor([wave_length]).to(device=device, dtype=dtype)
|
205 |
+
dim_mat = dim_mat**((4. / feat_dim) * feat_range)
|
206 |
+
dim_mat = dim_mat.view((1, 1, -1))
|
207 |
+
|
208 |
+
embedding_x = torch.cat(
|
209 |
+
((w_diff / dim_mat).sin(), (w_diff / dim_mat).cos()), dim=2)
|
210 |
+
|
211 |
+
embedding_y = torch.cat(
|
212 |
+
((h_diff / dim_mat).sin(), (h_diff / dim_mat).cos()), dim=2)
|
213 |
+
|
214 |
+
return embedding_x, embedding_y
|
215 |
+
|
216 |
+
def forward(self, x_input):
|
217 |
+
num_heads = self.num_heads
|
218 |
+
|
219 |
+
# use empirical_attention
|
220 |
+
if self.q_downsample is not None:
|
221 |
+
x_q = self.q_downsample(x_input)
|
222 |
+
else:
|
223 |
+
x_q = x_input
|
224 |
+
n, _, h, w = x_q.shape
|
225 |
+
|
226 |
+
if self.kv_downsample is not None:
|
227 |
+
x_kv = self.kv_downsample(x_input)
|
228 |
+
else:
|
229 |
+
x_kv = x_input
|
230 |
+
_, _, h_kv, w_kv = x_kv.shape
|
231 |
+
|
232 |
+
if self.attention_type[0] or self.attention_type[1]:
|
233 |
+
proj_query = self.query_conv(x_q).view(
|
234 |
+
(n, num_heads, self.qk_embed_dim, h * w))
|
235 |
+
proj_query = proj_query.permute(0, 1, 3, 2)
|
236 |
+
|
237 |
+
if self.attention_type[0] or self.attention_type[2]:
|
238 |
+
proj_key = self.key_conv(x_kv).view(
|
239 |
+
(n, num_heads, self.qk_embed_dim, h_kv * w_kv))
|
240 |
+
|
241 |
+
if self.attention_type[1] or self.attention_type[3]:
|
242 |
+
position_embed_x, position_embed_y = self.get_position_embedding(
|
243 |
+
h, w, h_kv, w_kv, self.q_stride, self.kv_stride,
|
244 |
+
x_input.device, x_input.dtype, self.position_embedding_dim)
|
245 |
+
# (n, num_heads, w, w_kv, dim)
|
246 |
+
position_feat_x = self.appr_geom_fc_x(position_embed_x).\
|
247 |
+
view(1, w, w_kv, num_heads, self.qk_embed_dim).\
|
248 |
+
permute(0, 3, 1, 2, 4).\
|
249 |
+
repeat(n, 1, 1, 1, 1)
|
250 |
+
|
251 |
+
# (n, num_heads, h, h_kv, dim)
|
252 |
+
position_feat_y = self.appr_geom_fc_y(position_embed_y).\
|
253 |
+
view(1, h, h_kv, num_heads, self.qk_embed_dim).\
|
254 |
+
permute(0, 3, 1, 2, 4).\
|
255 |
+
repeat(n, 1, 1, 1, 1)
|
256 |
+
|
257 |
+
position_feat_x /= math.sqrt(2)
|
258 |
+
position_feat_y /= math.sqrt(2)
|
259 |
+
|
260 |
+
# accelerate for saliency only
|
261 |
+
if (np.sum(self.attention_type) == 1) and self.attention_type[2]:
|
262 |
+
appr_bias = self.appr_bias.\
|
263 |
+
view(1, num_heads, 1, self.qk_embed_dim).\
|
264 |
+
repeat(n, 1, 1, 1)
|
265 |
+
|
266 |
+
energy = torch.matmul(appr_bias, proj_key).\
|
267 |
+
view(n, num_heads, 1, h_kv * w_kv)
|
268 |
+
|
269 |
+
h = 1
|
270 |
+
w = 1
|
271 |
+
else:
|
272 |
+
# (n, num_heads, h*w, h_kv*w_kv), query before key, 540mb for
|
273 |
+
if not self.attention_type[0]:
|
274 |
+
energy = torch.zeros(
|
275 |
+
n,
|
276 |
+
num_heads,
|
277 |
+
h,
|
278 |
+
w,
|
279 |
+
h_kv,
|
280 |
+
w_kv,
|
281 |
+
dtype=x_input.dtype,
|
282 |
+
device=x_input.device)
|
283 |
+
|
284 |
+
# attention_type[0]: appr - appr
|
285 |
+
# attention_type[1]: appr - position
|
286 |
+
# attention_type[2]: bias - appr
|
287 |
+
# attention_type[3]: bias - position
|
288 |
+
if self.attention_type[0] or self.attention_type[2]:
|
289 |
+
if self.attention_type[0] and self.attention_type[2]:
|
290 |
+
appr_bias = self.appr_bias.\
|
291 |
+
view(1, num_heads, 1, self.qk_embed_dim)
|
292 |
+
energy = torch.matmul(proj_query + appr_bias, proj_key).\
|
293 |
+
view(n, num_heads, h, w, h_kv, w_kv)
|
294 |
+
|
295 |
+
elif self.attention_type[0]:
|
296 |
+
energy = torch.matmul(proj_query, proj_key).\
|
297 |
+
view(n, num_heads, h, w, h_kv, w_kv)
|
298 |
+
|
299 |
+
elif self.attention_type[2]:
|
300 |
+
appr_bias = self.appr_bias.\
|
301 |
+
view(1, num_heads, 1, self.qk_embed_dim).\
|
302 |
+
repeat(n, 1, 1, 1)
|
303 |
+
|
304 |
+
energy += torch.matmul(appr_bias, proj_key).\
|
305 |
+
view(n, num_heads, 1, 1, h_kv, w_kv)
|
306 |
+
|
307 |
+
if self.attention_type[1] or self.attention_type[3]:
|
308 |
+
if self.attention_type[1] and self.attention_type[3]:
|
309 |
+
geom_bias = self.geom_bias.\
|
310 |
+
view(1, num_heads, 1, self.qk_embed_dim)
|
311 |
+
|
312 |
+
proj_query_reshape = (proj_query + geom_bias).\
|
313 |
+
view(n, num_heads, h, w, self.qk_embed_dim)
|
314 |
+
|
315 |
+
energy_x = torch.matmul(
|
316 |
+
proj_query_reshape.permute(0, 1, 3, 2, 4),
|
317 |
+
position_feat_x.permute(0, 1, 2, 4, 3))
|
318 |
+
energy_x = energy_x.\
|
319 |
+
permute(0, 1, 3, 2, 4).unsqueeze(4)
|
320 |
+
|
321 |
+
energy_y = torch.matmul(
|
322 |
+
proj_query_reshape,
|
323 |
+
position_feat_y.permute(0, 1, 2, 4, 3))
|
324 |
+
energy_y = energy_y.unsqueeze(5)
|
325 |
+
|
326 |
+
energy += energy_x + energy_y
|
327 |
+
|
328 |
+
elif self.attention_type[1]:
|
329 |
+
proj_query_reshape = proj_query.\
|
330 |
+
view(n, num_heads, h, w, self.qk_embed_dim)
|
331 |
+
proj_query_reshape = proj_query_reshape.\
|
332 |
+
permute(0, 1, 3, 2, 4)
|
333 |
+
position_feat_x_reshape = position_feat_x.\
|
334 |
+
permute(0, 1, 2, 4, 3)
|
335 |
+
position_feat_y_reshape = position_feat_y.\
|
336 |
+
permute(0, 1, 2, 4, 3)
|
337 |
+
|
338 |
+
energy_x = torch.matmul(proj_query_reshape,
|
339 |
+
position_feat_x_reshape)
|
340 |
+
energy_x = energy_x.permute(0, 1, 3, 2, 4).unsqueeze(4)
|
341 |
+
|
342 |
+
energy_y = torch.matmul(proj_query_reshape,
|
343 |
+
position_feat_y_reshape)
|
344 |
+
energy_y = energy_y.unsqueeze(5)
|
345 |
+
|
346 |
+
energy += energy_x + energy_y
|
347 |
+
|
348 |
+
elif self.attention_type[3]:
|
349 |
+
geom_bias = self.geom_bias.\
|
350 |
+
view(1, num_heads, self.qk_embed_dim, 1).\
|
351 |
+
repeat(n, 1, 1, 1)
|
352 |
+
|
353 |
+
position_feat_x_reshape = position_feat_x.\
|
354 |
+
view(n, num_heads, w*w_kv, self.qk_embed_dim)
|
355 |
+
|
356 |
+
position_feat_y_reshape = position_feat_y.\
|
357 |
+
view(n, num_heads, h * h_kv, self.qk_embed_dim)
|
358 |
+
|
359 |
+
energy_x = torch.matmul(position_feat_x_reshape, geom_bias)
|
360 |
+
energy_x = energy_x.view(n, num_heads, 1, w, 1, w_kv)
|
361 |
+
|
362 |
+
energy_y = torch.matmul(position_feat_y_reshape, geom_bias)
|
363 |
+
energy_y = energy_y.view(n, num_heads, h, 1, h_kv, 1)
|
364 |
+
|
365 |
+
energy += energy_x + energy_y
|
366 |
+
|
367 |
+
energy = energy.view(n, num_heads, h * w, h_kv * w_kv)
|
368 |
+
|
369 |
+
if self.spatial_range >= 0:
|
370 |
+
cur_local_constraint_map = \
|
371 |
+
self.local_constraint_map[:h, :w, :h_kv, :w_kv].\
|
372 |
+
contiguous().\
|
373 |
+
view(1, 1, h*w, h_kv*w_kv)
|
374 |
+
|
375 |
+
energy = energy.masked_fill_(cur_local_constraint_map,
|
376 |
+
float('-inf'))
|
377 |
+
|
378 |
+
attention = F.softmax(energy, 3)
|
379 |
+
|
380 |
+
proj_value = self.value_conv(x_kv)
|
381 |
+
proj_value_reshape = proj_value.\
|
382 |
+
view((n, num_heads, self.v_dim, h_kv * w_kv)).\
|
383 |
+
permute(0, 1, 3, 2)
|
384 |
+
|
385 |
+
out = torch.matmul(attention, proj_value_reshape).\
|
386 |
+
permute(0, 1, 3, 2).\
|
387 |
+
contiguous().\
|
388 |
+
view(n, self.v_dim * self.num_heads, h, w)
|
389 |
+
|
390 |
+
out = self.proj_conv(out)
|
391 |
+
|
392 |
+
# output is downsampled, upsample back to input size
|
393 |
+
if self.q_downsample is not None:
|
394 |
+
out = F.interpolate(
|
395 |
+
out,
|
396 |
+
size=x_input.shape[2:],
|
397 |
+
mode='bilinear',
|
398 |
+
align_corners=False)
|
399 |
+
|
400 |
+
out = self.gamma * out + x_input
|
401 |
+
return out
|
402 |
+
|
403 |
+
def init_weights(self):
|
404 |
+
for m in self.modules():
|
405 |
+
if hasattr(m, 'kaiming_init') and m.kaiming_init:
|
406 |
+
kaiming_init(
|
407 |
+
m,
|
408 |
+
mode='fan_in',
|
409 |
+
nonlinearity='leaky_relu',
|
410 |
+
bias=0,
|
411 |
+
distribution='uniform',
|
412 |
+
a=1)
|
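A minimal usage sketch of GeneralizedAttention (import path assumed). The default spatial_range=-1 skips the local constraint map, and attention_type='1111' enables all four attention terms:

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.generalized_attention import GeneralizedAttention

attn = GeneralizedAttention(in_channels=256, num_heads=8, attention_type='1111')
out = attn(torch.randn(2, 256, 32, 32))  # gamma-scaled residual output, same shape as the input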
custom_mmpkg/custom_mmcv/cnn/bricks/hsigmoid.py
ADDED
@@ -0,0 +1,34 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .registry import ACTIVATION_LAYERS
|
5 |
+
|
6 |
+
|
7 |
+
@ACTIVATION_LAYERS.register_module()
|
8 |
+
class HSigmoid(nn.Module):
|
9 |
+
"""Hard Sigmoid Module. Apply the hard sigmoid function:
|
10 |
+
Hsigmoid(x) = min(max((x + bias) / divisor, min_value), max_value)
|
11 |
+
Default: Hsigmoid(x) = min(max((x + 1) / 2, 0), 1)
|
12 |
+
|
13 |
+
Args:
|
14 |
+
bias (float): Bias of the input feature map. Default: 1.0.
|
15 |
+
divisor (float): Divisor of the input feature map. Default: 2.0.
|
16 |
+
min_value (float): Lower bound value. Default: 0.0.
|
17 |
+
max_value (float): Upper bound value. Default: 1.0.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Tensor: The output tensor.
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, bias=1.0, divisor=2.0, min_value=0.0, max_value=1.0):
|
24 |
+
super(HSigmoid, self).__init__()
|
25 |
+
self.bias = bias
|
26 |
+
self.divisor = divisor
|
27 |
+
assert self.divisor != 0
|
28 |
+
self.min_value = min_value
|
29 |
+
self.max_value = max_value
|
30 |
+
|
31 |
+
def forward(self, x):
|
32 |
+
x = (x + self.bias) / self.divisor
|
33 |
+
|
34 |
+
return x.clamp_(self.min_value, self.max_value)
|
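A minimal sketch of the default hard sigmoid, min(max((x + 1) / 2, 0), 1) (import path assumed):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.hsigmoid import HSigmoid

act = HSigmoid()
print(act(torch.tensor([-3.0, 0.0, 3.0])))  # tensor([0.0000, 0.5000, 1.0000])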
custom_mmpkg/custom_mmcv/cnn/bricks/hswish.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .registry import ACTIVATION_LAYERS
|
5 |
+
|
6 |
+
|
7 |
+
@ACTIVATION_LAYERS.register_module()
|
8 |
+
class HSwish(nn.Module):
|
9 |
+
"""Hard Swish Module.
|
10 |
+
|
11 |
+
This module applies the hard swish function:
|
12 |
+
|
13 |
+
.. math::
|
14 |
+
Hswish(x) = x * ReLU6(x + 3) / 6
|
15 |
+
|
16 |
+
Args:
|
17 |
+
inplace (bool): can optionally do the operation in-place.
|
18 |
+
Default: False.
|
19 |
+
|
20 |
+
Returns:
|
21 |
+
Tensor: The output tensor.
|
22 |
+
"""
|
23 |
+
|
24 |
+
def __init__(self, inplace=False):
|
25 |
+
super(HSwish, self).__init__()
|
26 |
+
self.act = nn.ReLU6(inplace)
|
27 |
+
|
28 |
+
def forward(self, x):
|
29 |
+
return x * self.act(x + 3) / 6
|
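A minimal sketch of the hard swish activation x * ReLU6(x + 3) / 6 (import path assumed):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.hswish import HSwish

act = HSwish()
out = act(torch.tensor([-4.0, 0.0, 4.0]))  # -> [-0., 0., 4.]: saturates to 0 below -3 and to x above +3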
custom_mmpkg/custom_mmcv/cnn/bricks/non_local.py
ADDED
@@ -0,0 +1,306 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from abc import ABCMeta
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from ..utils import constant_init, normal_init
|
8 |
+
from .conv_module import ConvModule
|
9 |
+
from .registry import PLUGIN_LAYERS
|
10 |
+
|
11 |
+
|
12 |
+
class _NonLocalNd(nn.Module, metaclass=ABCMeta):
|
13 |
+
"""Basic Non-local module.
|
14 |
+
|
15 |
+
This module is proposed in
|
16 |
+
"Non-local Neural Networks"
|
17 |
+
Paper reference: https://arxiv.org/abs/1711.07971
|
18 |
+
Code reference: https://github.com/AlexHex7/Non-local_pytorch
|
19 |
+
|
20 |
+
Args:
|
21 |
+
in_channels (int): Channels of the input feature map.
|
22 |
+
reduction (int): Channel reduction ratio. Default: 2.
|
23 |
+
use_scale (bool): Whether to scale pairwise_weight by
|
24 |
+
`1/sqrt(inter_channels)` when the mode is `embedded_gaussian`.
|
25 |
+
Default: True.
|
26 |
+
conv_cfg (None | dict): The config dict for convolution layers.
|
27 |
+
If not specified, it will use `nn.Conv2d` for convolution layers.
|
28 |
+
Default: None.
|
29 |
+
norm_cfg (None | dict): The config dict for normalization layers.
|
30 |
+
Default: None. (This parameter is only applicable to conv_out.)
|
31 |
+
mode (str): Options are `gaussian`, `concatenation`,
|
32 |
+
`embedded_gaussian` and `dot_product`. Default: embedded_gaussian.
|
33 |
+
"""
|
34 |
+
|
35 |
+
def __init__(self,
|
36 |
+
in_channels,
|
37 |
+
reduction=2,
|
38 |
+
use_scale=True,
|
39 |
+
conv_cfg=None,
|
40 |
+
norm_cfg=None,
|
41 |
+
mode='embedded_gaussian',
|
42 |
+
**kwargs):
|
43 |
+
super(_NonLocalNd, self).__init__()
|
44 |
+
self.in_channels = in_channels
|
45 |
+
self.reduction = reduction
|
46 |
+
self.use_scale = use_scale
|
47 |
+
self.inter_channels = max(in_channels // reduction, 1)
|
48 |
+
self.mode = mode
|
49 |
+
|
50 |
+
if mode not in [
|
51 |
+
'gaussian', 'embedded_gaussian', 'dot_product', 'concatenation'
|
52 |
+
]:
|
53 |
+
raise ValueError("Mode should be in 'gaussian', 'concatenation', "
|
54 |
+
f"'embedded_gaussian' or 'dot_product', but got "
|
55 |
+
f'{mode} instead.')
|
56 |
+
|
57 |
+
# g, theta, phi are defaulted as `nn.ConvNd`.
|
58 |
+
# Here we use ConvModule for potential usage.
|
59 |
+
self.g = ConvModule(
|
60 |
+
self.in_channels,
|
61 |
+
self.inter_channels,
|
62 |
+
kernel_size=1,
|
63 |
+
conv_cfg=conv_cfg,
|
64 |
+
act_cfg=None)
|
65 |
+
self.conv_out = ConvModule(
|
66 |
+
self.inter_channels,
|
67 |
+
self.in_channels,
|
68 |
+
kernel_size=1,
|
69 |
+
conv_cfg=conv_cfg,
|
70 |
+
norm_cfg=norm_cfg,
|
71 |
+
act_cfg=None)
|
72 |
+
|
73 |
+
if self.mode != 'gaussian':
|
74 |
+
self.theta = ConvModule(
|
75 |
+
self.in_channels,
|
76 |
+
self.inter_channels,
|
77 |
+
kernel_size=1,
|
78 |
+
conv_cfg=conv_cfg,
|
79 |
+
act_cfg=None)
|
80 |
+
self.phi = ConvModule(
|
81 |
+
self.in_channels,
|
82 |
+
self.inter_channels,
|
83 |
+
kernel_size=1,
|
84 |
+
conv_cfg=conv_cfg,
|
85 |
+
act_cfg=None)
|
86 |
+
|
87 |
+
if self.mode == 'concatenation':
|
88 |
+
self.concat_project = ConvModule(
|
89 |
+
self.inter_channels * 2,
|
90 |
+
1,
|
91 |
+
kernel_size=1,
|
92 |
+
stride=1,
|
93 |
+
padding=0,
|
94 |
+
bias=False,
|
95 |
+
act_cfg=dict(type='ReLU'))
|
96 |
+
|
97 |
+
self.init_weights(**kwargs)
|
98 |
+
|
99 |
+
def init_weights(self, std=0.01, zeros_init=True):
|
100 |
+
if self.mode != 'gaussian':
|
101 |
+
for m in [self.g, self.theta, self.phi]:
|
102 |
+
normal_init(m.conv, std=std)
|
103 |
+
else:
|
104 |
+
normal_init(self.g.conv, std=std)
|
105 |
+
if zeros_init:
|
106 |
+
if self.conv_out.norm_cfg is None:
|
107 |
+
constant_init(self.conv_out.conv, 0)
|
108 |
+
else:
|
109 |
+
constant_init(self.conv_out.norm, 0)
|
110 |
+
else:
|
111 |
+
if self.conv_out.norm_cfg is None:
|
112 |
+
normal_init(self.conv_out.conv, std=std)
|
113 |
+
else:
|
114 |
+
normal_init(self.conv_out.norm, std=std)
|
115 |
+
|
116 |
+
def gaussian(self, theta_x, phi_x):
|
117 |
+
# NonLocal1d pairwise_weight: [N, H, H]
|
118 |
+
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
119 |
+
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
120 |
+
pairwise_weight = torch.matmul(theta_x, phi_x)
|
121 |
+
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
122 |
+
return pairwise_weight
|
123 |
+
|
124 |
+
def embedded_gaussian(self, theta_x, phi_x):
|
125 |
+
# NonLocal1d pairwise_weight: [N, H, H]
|
126 |
+
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
127 |
+
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
128 |
+
pairwise_weight = torch.matmul(theta_x, phi_x)
|
129 |
+
if self.use_scale:
|
130 |
+
# theta_x.shape[-1] is `self.inter_channels`
|
131 |
+
pairwise_weight /= theta_x.shape[-1]**0.5
|
132 |
+
pairwise_weight = pairwise_weight.softmax(dim=-1)
|
133 |
+
return pairwise_weight
|
134 |
+
|
135 |
+
def dot_product(self, theta_x, phi_x):
|
136 |
+
# NonLocal1d pairwise_weight: [N, H, H]
|
137 |
+
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
138 |
+
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
139 |
+
pairwise_weight = torch.matmul(theta_x, phi_x)
|
140 |
+
pairwise_weight /= pairwise_weight.shape[-1]
|
141 |
+
return pairwise_weight
|
142 |
+
|
143 |
+
def concatenation(self, theta_x, phi_x):
|
144 |
+
# NonLocal1d pairwise_weight: [N, H, H]
|
145 |
+
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
146 |
+
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
147 |
+
h = theta_x.size(2)
|
148 |
+
w = phi_x.size(3)
|
149 |
+
theta_x = theta_x.repeat(1, 1, 1, w)
|
150 |
+
phi_x = phi_x.repeat(1, 1, h, 1)
|
151 |
+
|
152 |
+
concat_feature = torch.cat([theta_x, phi_x], dim=1)
|
153 |
+
pairwise_weight = self.concat_project(concat_feature)
|
154 |
+
n, _, h, w = pairwise_weight.size()
|
155 |
+
pairwise_weight = pairwise_weight.view(n, h, w)
|
156 |
+
pairwise_weight /= pairwise_weight.shape[-1]
|
157 |
+
|
158 |
+
return pairwise_weight
|
159 |
+
|
160 |
+
def forward(self, x):
|
161 |
+
# Assume `reduction = 1`, then `inter_channels = C`
|
162 |
+
# or `inter_channels = C` when `mode="gaussian"`
|
163 |
+
|
164 |
+
# NonLocal1d x: [N, C, H]
|
165 |
+
# NonLocal2d x: [N, C, H, W]
|
166 |
+
# NonLocal3d x: [N, C, T, H, W]
|
167 |
+
n = x.size(0)
|
168 |
+
|
169 |
+
# NonLocal1d g_x: [N, H, C]
|
170 |
+
# NonLocal2d g_x: [N, HxW, C]
|
171 |
+
# NonLocal3d g_x: [N, TxHxW, C]
|
172 |
+
g_x = self.g(x).view(n, self.inter_channels, -1)
|
173 |
+
g_x = g_x.permute(0, 2, 1)
|
174 |
+
|
175 |
+
# NonLocal1d theta_x: [N, H, C], phi_x: [N, C, H]
|
176 |
+
# NonLocal2d theta_x: [N, HxW, C], phi_x: [N, C, HxW]
|
177 |
+
# NonLocal3d theta_x: [N, TxHxW, C], phi_x: [N, C, TxHxW]
|
178 |
+
if self.mode == 'gaussian':
|
179 |
+
theta_x = x.view(n, self.in_channels, -1)
|
180 |
+
theta_x = theta_x.permute(0, 2, 1)
|
181 |
+
if self.sub_sample:
|
182 |
+
phi_x = self.phi(x).view(n, self.in_channels, -1)
|
183 |
+
else:
|
184 |
+
phi_x = x.view(n, self.in_channels, -1)
|
185 |
+
elif self.mode == 'concatenation':
|
186 |
+
theta_x = self.theta(x).view(n, self.inter_channels, -1, 1)
|
187 |
+
phi_x = self.phi(x).view(n, self.inter_channels, 1, -1)
|
188 |
+
else:
|
189 |
+
theta_x = self.theta(x).view(n, self.inter_channels, -1)
|
190 |
+
theta_x = theta_x.permute(0, 2, 1)
|
191 |
+
phi_x = self.phi(x).view(n, self.inter_channels, -1)
|
192 |
+
|
193 |
+
pairwise_func = getattr(self, self.mode)
|
194 |
+
# NonLocal1d pairwise_weight: [N, H, H]
|
195 |
+
# NonLocal2d pairwise_weight: [N, HxW, HxW]
|
196 |
+
# NonLocal3d pairwise_weight: [N, TxHxW, TxHxW]
|
197 |
+
pairwise_weight = pairwise_func(theta_x, phi_x)
|
198 |
+
|
199 |
+
# NonLocal1d y: [N, H, C]
|
200 |
+
# NonLocal2d y: [N, HxW, C]
|
201 |
+
# NonLocal3d y: [N, TxHxW, C]
|
202 |
+
y = torch.matmul(pairwise_weight, g_x)
|
203 |
+
# NonLocal1d y: [N, C, H]
|
204 |
+
# NonLocal2d y: [N, C, H, W]
|
205 |
+
# NonLocal3d y: [N, C, T, H, W]
|
206 |
+
y = y.permute(0, 2, 1).contiguous().reshape(n, self.inter_channels,
|
207 |
+
*x.size()[2:])
|
208 |
+
|
209 |
+
output = x + self.conv_out(y)
|
210 |
+
|
211 |
+
return output
|
212 |
+
|
213 |
+
|
214 |
+
class NonLocal1d(_NonLocalNd):
|
215 |
+
"""1D Non-local module.
|
216 |
+
|
217 |
+
Args:
|
218 |
+
in_channels (int): Same as `NonLocalND`.
|
219 |
+
sub_sample (bool): Whether to apply max pooling after pairwise
|
220 |
+
function (Note that the `sub_sample` is applied on spatial only).
|
221 |
+
Default: False.
|
222 |
+
conv_cfg (None | dict): Same as `NonLocalND`.
|
223 |
+
Default: dict(type='Conv1d').
|
224 |
+
"""
|
225 |
+
|
226 |
+
def __init__(self,
|
227 |
+
in_channels,
|
228 |
+
sub_sample=False,
|
229 |
+
conv_cfg=dict(type='Conv1d'),
|
230 |
+
**kwargs):
|
231 |
+
super(NonLocal1d, self).__init__(
|
232 |
+
in_channels, conv_cfg=conv_cfg, **kwargs)
|
233 |
+
|
234 |
+
self.sub_sample = sub_sample
|
235 |
+
|
236 |
+
if sub_sample:
|
237 |
+
max_pool_layer = nn.MaxPool1d(kernel_size=2)
|
238 |
+
self.g = nn.Sequential(self.g, max_pool_layer)
|
239 |
+
if self.mode != 'gaussian':
|
240 |
+
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
241 |
+
else:
|
242 |
+
self.phi = max_pool_layer
|
243 |
+
|
244 |
+
|
245 |
+
@PLUGIN_LAYERS.register_module()
|
246 |
+
class NonLocal2d(_NonLocalNd):
|
247 |
+
"""2D Non-local module.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
in_channels (int): Same as `NonLocalND`.
|
251 |
+
sub_sample (bool): Whether to apply max pooling after pairwise
|
252 |
+
function (Note that the `sub_sample` is applied on spatial only).
|
253 |
+
Default: False.
|
254 |
+
conv_cfg (None | dict): Same as `NonLocalND`.
|
255 |
+
Default: dict(type='Conv2d').
|
256 |
+
"""
|
257 |
+
|
258 |
+
_abbr_ = 'nonlocal_block'
|
259 |
+
|
260 |
+
def __init__(self,
|
261 |
+
in_channels,
|
262 |
+
sub_sample=False,
|
263 |
+
conv_cfg=dict(type='Conv2d'),
|
264 |
+
**kwargs):
|
265 |
+
super(NonLocal2d, self).__init__(
|
266 |
+
in_channels, conv_cfg=conv_cfg, **kwargs)
|
267 |
+
|
268 |
+
self.sub_sample = sub_sample
|
269 |
+
|
270 |
+
if sub_sample:
|
271 |
+
max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
|
272 |
+
self.g = nn.Sequential(self.g, max_pool_layer)
|
273 |
+
if self.mode != 'gaussian':
|
274 |
+
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
275 |
+
else:
|
276 |
+
self.phi = max_pool_layer
|
277 |
+
|
278 |
+
|
279 |
+
class NonLocal3d(_NonLocalNd):
|
280 |
+
"""3D Non-local module.
|
281 |
+
|
282 |
+
Args:
|
283 |
+
in_channels (int): Same as `NonLocalND`.
|
284 |
+
sub_sample (bool): Whether to apply max pooling after pairwise
|
285 |
+
function (Note that the `sub_sample` is applied on spatial only).
|
286 |
+
Default: False.
|
287 |
+
conv_cfg (None | dict): Same as `NonLocalND`.
|
288 |
+
Default: dict(type='Conv3d').
|
289 |
+
"""
|
290 |
+
|
291 |
+
def __init__(self,
|
292 |
+
in_channels,
|
293 |
+
sub_sample=False,
|
294 |
+
conv_cfg=dict(type='Conv3d'),
|
295 |
+
**kwargs):
|
296 |
+
super(NonLocal3d, self).__init__(
|
297 |
+
in_channels, conv_cfg=conv_cfg, **kwargs)
|
298 |
+
self.sub_sample = sub_sample
|
299 |
+
|
300 |
+
if sub_sample:
|
301 |
+
max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
|
302 |
+
self.g = nn.Sequential(self.g, max_pool_layer)
|
303 |
+
if self.mode != 'gaussian':
|
304 |
+
self.phi = nn.Sequential(self.phi, max_pool_layer)
|
305 |
+
else:
|
306 |
+
self.phi = max_pool_layer
|
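A minimal usage sketch of the registered 2D variant (import path assumed); with sub_sample=True the key/value branches are additionally max-pooled 2x:

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.non_local import NonLocal2d

block = NonLocal2d(in_channels=64, reduction=2, mode='embedded_gaussian', sub_sample=True)
out = block(torch.randn(2, 64, 28, 28))  # residual output, same shape as the input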
custom_mmpkg/custom_mmcv/cnn/bricks/norm.py
ADDED
@@ -0,0 +1,144 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import inspect
|
3 |
+
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from custom_mmpkg.custom_mmcv.utils import is_tuple_of
|
7 |
+
from custom_mmpkg.custom_mmcv.utils.parrots_wrapper import SyncBatchNorm, _BatchNorm, _InstanceNorm
|
8 |
+
from .registry import NORM_LAYERS
|
9 |
+
|
10 |
+
NORM_LAYERS.register_module('BN', module=nn.BatchNorm2d)
|
11 |
+
NORM_LAYERS.register_module('BN1d', module=nn.BatchNorm1d)
|
12 |
+
NORM_LAYERS.register_module('BN2d', module=nn.BatchNorm2d)
|
13 |
+
NORM_LAYERS.register_module('BN3d', module=nn.BatchNorm3d)
|
14 |
+
NORM_LAYERS.register_module('SyncBN', module=SyncBatchNorm)
|
15 |
+
NORM_LAYERS.register_module('GN', module=nn.GroupNorm)
|
16 |
+
NORM_LAYERS.register_module('LN', module=nn.LayerNorm)
|
17 |
+
NORM_LAYERS.register_module('IN', module=nn.InstanceNorm2d)
|
18 |
+
NORM_LAYERS.register_module('IN1d', module=nn.InstanceNorm1d)
|
19 |
+
NORM_LAYERS.register_module('IN2d', module=nn.InstanceNorm2d)
|
20 |
+
NORM_LAYERS.register_module('IN3d', module=nn.InstanceNorm3d)
|
21 |
+
|
22 |
+
|
23 |
+
def infer_abbr(class_type):
|
24 |
+
"""Infer abbreviation from the class name.
|
25 |
+
|
26 |
+
When we build a norm layer with `build_norm_layer()`, we want to preserve
|
27 |
+
the norm type in variable names, e.g, self.bn1, self.gn. This method will
|
28 |
+
infer the abbreviation to map class types to abbreviations.
|
29 |
+
|
30 |
+
Rule 1: If the class has the property "_abbr_", return the property.
|
31 |
+
Rule 2: If the parent class is _BatchNorm, GroupNorm, LayerNorm or
|
32 |
+
InstanceNorm, the abbreviation of this layer will be "bn", "gn", "ln" and
|
33 |
+
"in" respectively.
|
34 |
+
Rule 3: If the class name contains "batch", "group", "layer" or "instance",
|
35 |
+
the abbreviation of this layer will be "bn", "gn", "ln" and "in"
|
36 |
+
respectively.
|
37 |
+
Rule 4: Otherwise, the abbreviation falls back to "norm_layer".
|
38 |
+
|
39 |
+
Args:
|
40 |
+
class_type (type): The norm layer type.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
str: The inferred abbreviation.
|
44 |
+
"""
|
45 |
+
if not inspect.isclass(class_type):
|
46 |
+
raise TypeError(
|
47 |
+
f'class_type must be a type, but got {type(class_type)}')
|
48 |
+
if hasattr(class_type, '_abbr_'):
|
49 |
+
return class_type._abbr_
|
50 |
+
if issubclass(class_type, _InstanceNorm): # IN is a subclass of BN
|
51 |
+
return 'in'
|
52 |
+
elif issubclass(class_type, _BatchNorm):
|
53 |
+
return 'bn'
|
54 |
+
elif issubclass(class_type, nn.GroupNorm):
|
55 |
+
return 'gn'
|
56 |
+
elif issubclass(class_type, nn.LayerNorm):
|
57 |
+
return 'ln'
|
58 |
+
else:
|
59 |
+
class_name = class_type.__name__.lower()
|
60 |
+
if 'batch' in class_name:
|
61 |
+
return 'bn'
|
62 |
+
elif 'group' in class_name:
|
63 |
+
return 'gn'
|
64 |
+
elif 'layer' in class_name:
|
65 |
+
return 'ln'
|
66 |
+
elif 'instance' in class_name:
|
67 |
+
return 'in'
|
68 |
+
else:
|
69 |
+
return 'norm_layer'
|
70 |
+
|
71 |
+
|
72 |
+
def build_norm_layer(cfg, num_features, postfix=''):
|
73 |
+
"""Build normalization layer.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
cfg (dict): The norm layer config, which should contain:
|
77 |
+
|
78 |
+
- type (str): Layer type.
|
79 |
+
- layer args: Args needed to instantiate a norm layer.
|
80 |
+
- requires_grad (bool, optional): Whether stop gradient updates.
|
81 |
+
num_features (int): Number of input channels.
|
82 |
+
postfix (int | str): The postfix to be appended into norm abbreviation
|
83 |
+
to create named layer.
|
84 |
+
|
85 |
+
Returns:
|
86 |
+
(str, nn.Module): The first element is the layer name consisting of
|
87 |
+
abbreviation and postfix, e.g., bn1, gn. The second element is the
|
88 |
+
created norm layer.
|
89 |
+
"""
|
90 |
+
if not isinstance(cfg, dict):
|
91 |
+
raise TypeError('cfg must be a dict')
|
92 |
+
if 'type' not in cfg:
|
93 |
+
raise KeyError('the cfg dict must contain the key "type"')
|
94 |
+
cfg_ = cfg.copy()
|
95 |
+
|
96 |
+
layer_type = cfg_.pop('type')
|
97 |
+
if layer_type not in NORM_LAYERS:
|
98 |
+
raise KeyError(f'Unrecognized norm type {layer_type}')
|
99 |
+
|
100 |
+
norm_layer = NORM_LAYERS.get(layer_type)
|
101 |
+
abbr = infer_abbr(norm_layer)
|
102 |
+
|
103 |
+
assert isinstance(postfix, (int, str))
|
104 |
+
name = abbr + str(postfix)
|
105 |
+
|
106 |
+
requires_grad = cfg_.pop('requires_grad', True)
|
107 |
+
cfg_.setdefault('eps', 1e-5)
|
108 |
+
if layer_type != 'GN':
|
109 |
+
layer = norm_layer(num_features, **cfg_)
|
110 |
+
if layer_type == 'SyncBN' and hasattr(layer, '_specify_ddp_gpu_num'):
|
111 |
+
layer._specify_ddp_gpu_num(1)
|
112 |
+
else:
|
113 |
+
assert 'num_groups' in cfg_
|
114 |
+
layer = norm_layer(num_channels=num_features, **cfg_)
|
115 |
+
|
116 |
+
for param in layer.parameters():
|
117 |
+
param.requires_grad = requires_grad
|
118 |
+
|
119 |
+
return name, layer
|
120 |
+
|
121 |
+
|
122 |
+
def is_norm(layer, exclude=None):
|
123 |
+
"""Check if a layer is a normalization layer.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
layer (nn.Module): The layer to be checked.
|
127 |
+
exclude (type | tuple[type]): Types to be excluded.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
bool: Whether the layer is a norm layer.
|
131 |
+
"""
|
132 |
+
if exclude is not None:
|
133 |
+
if not isinstance(exclude, tuple):
|
134 |
+
exclude = (exclude, )
|
135 |
+
if not is_tuple_of(exclude, type):
|
136 |
+
raise TypeError(
|
137 |
+
f'"exclude" must be either None or type or a tuple of types, '
|
138 |
+
f'but got {type(exclude)}: {exclude}')
|
139 |
+
|
140 |
+
if exclude and isinstance(layer, exclude):
|
141 |
+
return False
|
142 |
+
|
143 |
+
all_norm_bases = (_BatchNorm, _InstanceNorm, nn.GroupNorm, nn.LayerNorm)
|
144 |
+
return isinstance(layer, all_norm_bases)
|
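A minimal sketch of build_norm_layer (import path assumed); the returned name is the inferred abbreviation plus the postfix:

import torch.nn as nn
from custom_mmpkg.custom_mmcv.cnn.bricks.norm import build_norm_layer

name, layer = build_norm_layer(dict(type='BN', requires_grad=True), num_features=64, postfix=1)
print(name, isinstance(layer, nn.BatchNorm2d))  # bn1 True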
custom_mmpkg/custom_mmcv/cnn/bricks/padding.py
ADDED
@@ -0,0 +1,36 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
from .registry import PADDING_LAYERS
|
5 |
+
|
6 |
+
PADDING_LAYERS.register_module('zero', module=nn.ZeroPad2d)
|
7 |
+
PADDING_LAYERS.register_module('reflect', module=nn.ReflectionPad2d)
|
8 |
+
PADDING_LAYERS.register_module('replicate', module=nn.ReplicationPad2d)
|
9 |
+
|
10 |
+
|
11 |
+
def build_padding_layer(cfg, *args, **kwargs):
|
12 |
+
"""Build padding layer.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
cfg (None or dict): The padding layer config, which should contain:
|
16 |
+
- type (str): Layer type.
|
17 |
+
- layer args: Args needed to instantiate a padding layer.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
nn.Module: Created padding layer.
|
21 |
+
"""
|
22 |
+
if not isinstance(cfg, dict):
|
23 |
+
raise TypeError('cfg must be a dict')
|
24 |
+
if 'type' not in cfg:
|
25 |
+
raise KeyError('the cfg dict must contain the key "type"')
|
26 |
+
|
27 |
+
cfg_ = cfg.copy()
|
28 |
+
padding_type = cfg_.pop('type')
|
29 |
+
if padding_type not in PADDING_LAYERS:
|
30 |
+
raise KeyError(f'Unrecognized padding type {padding_type}.')
|
31 |
+
else:
|
32 |
+
padding_layer = PADDING_LAYERS.get(padding_type)
|
33 |
+
|
34 |
+
layer = padding_layer(*args, **kwargs, **cfg_)
|
35 |
+
|
36 |
+
return layer
|
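A minimal sketch of build_padding_layer (import path assumed):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.padding import build_padding_layer

pad = build_padding_layer(dict(type='reflect'), 2)  # builds nn.ReflectionPad2d(2)
out = pad(torch.randn(1, 3, 8, 8))                  # -> shape (1, 3, 12, 12)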
custom_mmpkg/custom_mmcv/cnn/bricks/plugin.py
ADDED
@@ -0,0 +1,88 @@
1 |
+
import inspect
|
2 |
+
import platform
|
3 |
+
|
4 |
+
from .registry import PLUGIN_LAYERS
|
5 |
+
|
6 |
+
if platform.system() == 'Windows':
|
7 |
+
import regex as re
|
8 |
+
else:
|
9 |
+
import re
|
10 |
+
|
11 |
+
|
12 |
+
def infer_abbr(class_type):
|
13 |
+
"""Infer abbreviation from the class name.
|
14 |
+
|
15 |
+
This method will infer the abbreviation to map class types to
|
16 |
+
abbreviations.
|
17 |
+
|
18 |
+
Rule 1: If the class has the property "_abbr_", return the property.
|
19 |
+
Rule 2: Otherwise, the abbreviation falls back to snake case of class
|
20 |
+
name, e.g. the abbreviation of ``FancyBlock`` will be ``fancy_block``.
|
21 |
+
|
22 |
+
Args:
|
23 |
+
class_type (type): The norm layer type.
|
24 |
+
|
25 |
+
Returns:
|
26 |
+
str: The inferred abbreviation.
|
27 |
+
"""
|
28 |
+
|
29 |
+
def camel2snack(word):
|
30 |
+
"""Convert camel case word into snack case.
|
31 |
+
|
32 |
+
Modified from `inflection lib
|
33 |
+
<https://inflection.readthedocs.io/en/latest/#inflection.underscore>`_.
|
34 |
+
|
35 |
+
Example::
|
36 |
+
|
37 |
+
>>> camel2snack("FancyBlock")
|
38 |
+
'fancy_block'
|
39 |
+
"""
|
40 |
+
|
41 |
+
word = re.sub(r'([A-Z]+)([A-Z][a-z])', r'\1_\2', word)
|
42 |
+
word = re.sub(r'([a-z\d])([A-Z])', r'\1_\2', word)
|
43 |
+
word = word.replace('-', '_')
|
44 |
+
return word.lower()
|
45 |
+
|
46 |
+
if not inspect.isclass(class_type):
|
47 |
+
raise TypeError(
|
48 |
+
f'class_type must be a type, but got {type(class_type)}')
|
49 |
+
if hasattr(class_type, '_abbr_'):
|
50 |
+
return class_type._abbr_
|
51 |
+
else:
|
52 |
+
return camel2snack(class_type.__name__)
|
53 |
+
|
54 |
+
|
55 |
+
def build_plugin_layer(cfg, postfix='', **kwargs):
|
56 |
+
"""Build plugin layer.
|
57 |
+
|
58 |
+
Args:
|
59 |
+
cfg (None or dict): cfg should contain:
|
60 |
+
type (str): identify plugin layer type.
|
61 |
+
layer args: args needed to instantiate a plugin layer.
|
62 |
+
postfix (int, str): appended into norm abbreviation to
|
63 |
+
create named layer. Default: ''.
|
64 |
+
|
65 |
+
Returns:
|
66 |
+
tuple[str, nn.Module]:
|
67 |
+
name (str): abbreviation + postfix
|
68 |
+
layer (nn.Module): created plugin layer
|
69 |
+
"""
|
70 |
+
if not isinstance(cfg, dict):
|
71 |
+
raise TypeError('cfg must be a dict')
|
72 |
+
if 'type' not in cfg:
|
73 |
+
raise KeyError('the cfg dict must contain the key "type"')
|
74 |
+
cfg_ = cfg.copy()
|
75 |
+
|
76 |
+
layer_type = cfg_.pop('type')
|
77 |
+
if layer_type not in PLUGIN_LAYERS:
|
78 |
+
raise KeyError(f'Unrecognized plugin type {layer_type}')
|
79 |
+
|
80 |
+
plugin_layer = PLUGIN_LAYERS.get(layer_type)
|
81 |
+
abbr = infer_abbr(plugin_layer)
|
82 |
+
|
83 |
+
assert isinstance(postfix, (int, str))
|
84 |
+
name = abbr + str(postfix)
|
85 |
+
|
86 |
+
layer = plugin_layer(**kwargs, **cfg_)
|
87 |
+
|
88 |
+
return name, layer
|
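A minimal sketch of build_plugin_layer with a plugin registered elsewhere in this package (import paths assumed; NonLocal2d must have been imported so it is present in PLUGIN_LAYERS). The returned name is the plugin's _abbr_ plus the postfix:

from custom_mmpkg.custom_mmcv.cnn.bricks.non_local import NonLocal2d  # noqa: F401, triggers registration
from custom_mmpkg.custom_mmcv.cnn.bricks.plugin import build_plugin_layer

name, layer = build_plugin_layer(dict(type='NonLocal2d', in_channels=64), postfix='_1')
print(name)  # nonlocal_block_1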
custom_mmpkg/custom_mmcv/cnn/bricks/registry.py
ADDED
@@ -0,0 +1,16 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
from custom_mmpkg.custom_mmcv.utils import Registry
|
3 |
+
|
4 |
+
CONV_LAYERS = Registry('conv layer')
|
5 |
+
NORM_LAYERS = Registry('norm layer')
|
6 |
+
ACTIVATION_LAYERS = Registry('activation layer')
|
7 |
+
PADDING_LAYERS = Registry('padding layer')
|
8 |
+
UPSAMPLE_LAYERS = Registry('upsample layer')
|
9 |
+
PLUGIN_LAYERS = Registry('plugin layer')
|
10 |
+
|
11 |
+
DROPOUT_LAYERS = Registry('drop out layers')
|
12 |
+
POSITIONAL_ENCODING = Registry('position encoding')
|
13 |
+
ATTENTION = Registry('attention')
|
14 |
+
FEEDFORWARD_NETWORK = Registry('feed-forward Network')
|
15 |
+
TRANSFORMER_LAYER = Registry('transformerLayer')
|
16 |
+
TRANSFORMER_LAYER_SEQUENCE = Registry('transformer-layers sequence')
|
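A minimal sketch of registering a custom layer in one of these registries and building it from a config dict (the layer below is hypothetical, and the build_from_cfg import path is assumed):

import torch.nn as nn
from custom_mmpkg.custom_mmcv.cnn.bricks.registry import ACTIVATION_LAYERS
from custom_mmpkg.custom_mmcv.utils import build_from_cfg

@ACTIVATION_LAYERS.register_module()
class ScaledTanh(nn.Module):  # hypothetical example layer
    def forward(self, x):
        return 2 * x.tanh()

act = build_from_cfg(dict(type='ScaledTanh'), ACTIVATION_LAYERS)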
custom_mmpkg/custom_mmcv/cnn/bricks/scale.py
ADDED
@@ -0,0 +1,21 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
|
6 |
+
class Scale(nn.Module):
|
7 |
+
"""A learnable scale parameter.
|
8 |
+
|
9 |
+
This layer scales the input by a learnable factor. It multiplies a
|
10 |
+
learnable scale parameter of shape (1,) with input of any shape.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
scale (float): Initial value of scale factor. Default: 1.0
|
14 |
+
"""
|
15 |
+
|
16 |
+
def __init__(self, scale=1.0):
|
17 |
+
super(Scale, self).__init__()
|
18 |
+
self.scale = nn.Parameter(torch.tensor(scale, dtype=torch.float))
|
19 |
+
|
20 |
+
def forward(self, x):
|
21 |
+
return x * self.scale
|
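A minimal usage sketch of Scale (import path assumed); the single learnable factor is broadcast over the whole input:

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.scale import Scale

scale = Scale(scale=0.5)
out = scale(torch.ones(2, 3))  # every element equals 0.5 until the parameter is trained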
custom_mmpkg/custom_mmcv/cnn/bricks/swish.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
# Copyright (c) OpenMMLab. All rights reserved.
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from .registry import ACTIVATION_LAYERS
|
6 |
+
|
7 |
+
|
8 |
+
@ACTIVATION_LAYERS.register_module()
|
9 |
+
class Swish(nn.Module):
|
10 |
+
"""Swish Module.
|
11 |
+
|
12 |
+
This module applies the swish function:
|
13 |
+
|
14 |
+
.. math::
|
15 |
+
Swish(x) = x * Sigmoid(x)
|
16 |
+
|
17 |
+
Returns:
|
18 |
+
Tensor: The output tensor.
|
19 |
+
"""
|
20 |
+
|
21 |
+
def __init__(self):
|
22 |
+
super(Swish, self).__init__()
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
return x * torch.sigmoid(x)
|
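A minimal sketch of the swish activation x * sigmoid(x) (import path assumed):

import torch
from custom_mmpkg.custom_mmcv.cnn.bricks.swish import Swish

act = Swish()
print(act(torch.tensor([0.0, 1.0])))  # tensor([0.0000, 0.7311])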
custom_mmpkg/custom_mmcv/cnn/bricks/transformer.py
ADDED
@@ -0,0 +1,595 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import warnings

import torch
import torch.nn as nn

from custom_mmpkg.custom_mmcv import ConfigDict, deprecated_api_warning
from custom_mmpkg.custom_mmcv.cnn import Linear, build_activation_layer, build_norm_layer
from custom_mmpkg.custom_mmcv.runner.base_module import BaseModule, ModuleList, Sequential
from custom_mmpkg.custom_mmcv.utils import build_from_cfg
from .drop import build_dropout
from .registry import (ATTENTION, FEEDFORWARD_NETWORK, POSITIONAL_ENCODING,
                       TRANSFORMER_LAYER, TRANSFORMER_LAYER_SEQUENCE)

# Avoid BC-breaking of importing MultiScaleDeformableAttention from this file
try:
    from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention  # noqa F401
    warnings.warn(
        ImportWarning(
            '``MultiScaleDeformableAttention`` has been moved to '
            '``mmcv.ops.multi_scale_deform_attn``, please change original path '  # noqa E501
            '``from custom_mmpkg.custom_mmcv.cnn.bricks.transformer import MultiScaleDeformableAttention`` '  # noqa E501
            'to ``from custom_mmpkg.custom_mmcv.ops.multi_scale_deform_attn import MultiScaleDeformableAttention`` '  # noqa E501
        ))

except ImportError:
    warnings.warn('Fail to import ``MultiScaleDeformableAttention`` from '
                  '``mmcv.ops.multi_scale_deform_attn``, '
                  'You should install ``mmcv-full`` if you need this module. ')


def build_positional_encoding(cfg, default_args=None):
    """Builder for Position Encoding."""
    return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)


def build_attention(cfg, default_args=None):
    """Builder for attention."""
    return build_from_cfg(cfg, ATTENTION, default_args)


def build_feedforward_network(cfg, default_args=None):
    """Builder for feed-forward network (FFN)."""
    return build_from_cfg(cfg, FEEDFORWARD_NETWORK, default_args)


def build_transformer_layer(cfg, default_args=None):
    """Builder for transformer layer."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER, default_args)


def build_transformer_layer_sequence(cfg, default_args=None):
    """Builder for transformer encoder and transformer decoder."""
    return build_from_cfg(cfg, TRANSFORMER_LAYER_SEQUENCE, default_args)


@ATTENTION.register_module()
class MultiheadAttention(BaseModule):
    """A wrapper for ``torch.nn.MultiheadAttention``.

    This module implements MultiheadAttention with identity connection,
    and positional encoding is also passed as input.

    Args:
        embed_dims (int): The embedding dimension.
        num_heads (int): Parallel attention heads.
        attn_drop (float): A Dropout layer on attn_output_weights.
            Default: 0.0.
        proj_drop (float): A Dropout layer after `nn.MultiheadAttention`.
            Default: 0.0.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): When it is True, Key, Query and Value are shape of
            (batch, n, embed_dim), otherwise (n, batch, embed_dim).
            Default to False.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 attn_drop=0.,
                 proj_drop=0.,
                 dropout_layer=dict(type='Dropout', drop_prob=0.),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):
        super(MultiheadAttention, self).__init__(init_cfg)
        if 'dropout' in kwargs:
            warnings.warn('The arguments `dropout` in MultiheadAttention '
                          'has been deprecated, now you can separately '
                          'set `attn_drop`(float), proj_drop(float), '
                          'and `dropout_layer`(dict) ')
            attn_drop = kwargs['dropout']
            dropout_layer['drop_prob'] = kwargs.pop('dropout')

        self.embed_dims = embed_dims
        self.num_heads = num_heads
        self.batch_first = batch_first

        self.attn = nn.MultiheadAttention(embed_dims, num_heads, attn_drop,
                                          **kwargs)

        self.proj_drop = nn.Dropout(proj_drop)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else nn.Identity()

    @deprecated_api_warning({'residual': 'identity'},
                            cls_name='MultiheadAttention')
    def forward(self,
                query,
                key=None,
                value=None,
                identity=None,
                query_pos=None,
                key_pos=None,
                attn_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `MultiheadAttention`.

        **kwargs allow passing a more general data flow when combining
        with other operations in `transformerlayer`.

        Args:
            query (Tensor): The input query with shape [num_queries, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
                If None, the ``query`` will be used. Defaults to None.
            value (Tensor): The value tensor with same shape as `key`.
                Same in `nn.MultiheadAttention.forward`. Defaults to None.
                If None, the `key` will be used.
            identity (Tensor): This tensor, with the same shape as x,
                will be used for the identity link.
                If None, `x` will be used. Defaults to None.
            query_pos (Tensor): The positional encoding for query, with
                the same shape as `x`. If not None, it will
                be added to `x` before forward function. Defaults to None.
            key_pos (Tensor): The positional encoding for `key`, with the
                same shape as `key`. Defaults to None. If not None, it will
                be added to `key` before forward function. If None, and
                `query_pos` has the same shape as `key`, then `query_pos`
                will be used for `key_pos`. Defaults to None.
            attn_mask (Tensor): ByteTensor mask with shape [num_queries,
                num_keys]. Same in `nn.MultiheadAttention.forward`.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor with shape [bs, num_keys].
                Defaults to None.

        Returns:
            Tensor: forwarded results with shape
            [num_queries, bs, embed_dims]
            if self.batch_first is False, else
            [bs, num_queries, embed_dims].
        """

        if key is None:
            key = query
        if value is None:
            value = key
        if identity is None:
            identity = query
        if key_pos is None:
            if query_pos is not None:
                # use query_pos if key_pos is not available
                if query_pos.shape == key.shape:
                    key_pos = query_pos
                else:
                    warnings.warn(f'position encoding of key is'
                                  f'missing in {self.__class__.__name__}.')
        if query_pos is not None:
            query = query + query_pos
        if key_pos is not None:
            key = key + key_pos

        # Because the dataflow('key', 'query', 'value') of
        # ``torch.nn.MultiheadAttention`` is (num_query, batch,
        # embed_dims), We should adjust the shape of dataflow from
        # batch_first (batch, num_query, embed_dims) to num_query_first
        # (num_query ,batch, embed_dims), and recover ``attn_output``
        # from num_query_first to batch_first.
        if self.batch_first:
            query = query.transpose(0, 1)
            key = key.transpose(0, 1)
            value = value.transpose(0, 1)

        out = self.attn(
            query=query,
            key=key,
            value=value,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask)[0]

        if self.batch_first:
            out = out.transpose(0, 1)

        return identity + self.dropout_layer(self.proj_drop(out))


@FEEDFORWARD_NETWORK.register_module()
class FFN(BaseModule):
    """Implements feed-forward networks (FFNs) with identity connection.

    Args:
        embed_dims (int): The feature dimension. Same as
            `MultiheadAttention`. Default: 256.
        feedforward_channels (int): The hidden dimension of FFNs.
            Default: 1024.
        num_fcs (int, optional): The number of fully-connected layers in
            FFNs. Default: 2.
        act_cfg (dict, optional): The activation config for FFNs.
            Default: dict(type='ReLU')
        ffn_drop (float, optional): Probability of an element to be
            zeroed in FFN. Default 0.0.
        add_identity (bool, optional): Whether to add the
            identity connection. Default: `True`.
        dropout_layer (obj:`ConfigDict`): The dropout_layer used
            when adding the shortcut.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    @deprecated_api_warning(
        {
            'dropout': 'ffn_drop',
            'add_residual': 'add_identity'
        },
        cls_name='FFN')
    def __init__(self,
                 embed_dims=256,
                 feedforward_channels=1024,
                 num_fcs=2,
                 act_cfg=dict(type='ReLU', inplace=True),
                 ffn_drop=0.,
                 dropout_layer=None,
                 add_identity=True,
                 init_cfg=None,
                 **kwargs):
        super(FFN, self).__init__(init_cfg)
        assert num_fcs >= 2, 'num_fcs should be no less ' \
            f'than 2. got {num_fcs}.'
        self.embed_dims = embed_dims
        self.feedforward_channels = feedforward_channels
        self.num_fcs = num_fcs
        self.act_cfg = act_cfg
        self.activate = build_activation_layer(act_cfg)

        layers = []
        in_channels = embed_dims
        for _ in range(num_fcs - 1):
            layers.append(
                Sequential(
                    Linear(in_channels, feedforward_channels), self.activate,
                    nn.Dropout(ffn_drop)))
            in_channels = feedforward_channels
        layers.append(Linear(feedforward_channels, embed_dims))
        layers.append(nn.Dropout(ffn_drop))
        self.layers = Sequential(*layers)
        self.dropout_layer = build_dropout(
            dropout_layer) if dropout_layer else torch.nn.Identity()
        self.add_identity = add_identity

    @deprecated_api_warning({'residual': 'identity'}, cls_name='FFN')
    def forward(self, x, identity=None):
        """Forward function for `FFN`.

        The function adds ``x`` to the output tensor if ``identity`` is None.
        """
        out = self.layers(x)
        if not self.add_identity:
            return self.dropout_layer(out)
        if identity is None:
            identity = x
        return identity + self.dropout_layer(out)


@TRANSFORMER_LAYER.register_module()
class BaseTransformerLayer(BaseModule):
    """Base `TransformerLayer` for vision transformer.

    It can be built from `mmcv.ConfigDict` and supports more flexible
    customization, for example, using any number of `FFN` or `LN` and
    different kinds of `attention` by specifying a list of `ConfigDict`
    named `attn_cfgs`. It is worth mentioning that it supports `prenorm`
    when you specify `norm` as the first element of `operation_order`.
    More details about the `prenorm`: `On Layer Normalization in the
    Transformer Architecture <https://arxiv.org/abs/2002.04745>`_ .

    Args:
        attn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for `self_attention` or `cross_attention` modules.
            The order of the configs in the list should be consistent with
            corresponding attentions in operation_order.
            If it is a dict, all of the attention modules in operation_order
            will be built with this config. Default: None.
        ffn_cfgs (list[`mmcv.ConfigDict`] | obj:`mmcv.ConfigDict` | None):
            Configs for FFN. The order of the configs in the list should be
            consistent with corresponding ffn in operation_order.
            If it is a dict, all of the FFN modules in operation_order
            will be built with this config.
        operation_order (tuple[str]): The execution order of operation
            in transformer. Such as ('self_attn', 'norm', 'ffn', 'norm').
            Supports `prenorm` when you specify the first element as `norm`.
            Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='LN').
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
        batch_first (bool): Key, Query and Value are shape
            of (batch, n, embed_dim)
            or (n, batch, embed_dim). Default to False.
    """

    def __init__(self,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
                     embed_dims=256,
                     feedforward_channels=1024,
                     num_fcs=2,
                     ffn_drop=0.,
                     act_cfg=dict(type='ReLU', inplace=True),
                 ),
                 operation_order=None,
                 norm_cfg=dict(type='LN'),
                 init_cfg=None,
                 batch_first=False,
                 **kwargs):

        deprecated_args = dict(
            feedforward_channels='feedforward_channels',
            ffn_dropout='ffn_drop',
            ffn_num_fcs='num_fcs')
        for ori_name, new_name in deprecated_args.items():
            if ori_name in kwargs:
                warnings.warn(
                    f'The arguments `{ori_name}` in BaseTransformerLayer '
                    f'has been deprecated, now you should set `{new_name}` '
                    f'and other FFN related arguments '
                    f'to a dict named `ffn_cfgs`. ')
                ffn_cfgs[new_name] = kwargs[ori_name]

        super(BaseTransformerLayer, self).__init__(init_cfg)

        self.batch_first = batch_first

        assert set(operation_order) & set(
            ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
            set(operation_order), f'The operation_order of' \
            f' {self.__class__.__name__} should ' \
            f'contains all four operation type ' \
            f"{['self_attn', 'norm', 'ffn', 'cross_attn']}"

        num_attn = operation_order.count('self_attn') + operation_order.count(
            'cross_attn')
        if isinstance(attn_cfgs, dict):
            attn_cfgs = [copy.deepcopy(attn_cfgs) for _ in range(num_attn)]
        else:
            assert num_attn == len(attn_cfgs), f'The length ' \
                f'of attn_cfg {num_attn} is ' \
                f'not consistent with the number of attention' \
                f'in operation_order {operation_order}.'

        self.num_attn = num_attn
        self.operation_order = operation_order
        self.norm_cfg = norm_cfg
        self.pre_norm = operation_order[0] == 'norm'
        self.attentions = ModuleList()

        index = 0
        for operation_name in operation_order:
            if operation_name in ['self_attn', 'cross_attn']:
                if 'batch_first' in attn_cfgs[index]:
                    assert self.batch_first == attn_cfgs[index]['batch_first']
                else:
                    attn_cfgs[index]['batch_first'] = self.batch_first
                attention = build_attention(attn_cfgs[index])
                # Some custom attentions used as `self_attn`
                # or `cross_attn` can have different behavior.
                attention.operation_name = operation_name
                self.attentions.append(attention)
                index += 1

        self.embed_dims = self.attentions[0].embed_dims

        self.ffns = ModuleList()
        num_ffns = operation_order.count('ffn')
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = ConfigDict(ffn_cfgs)
        if isinstance(ffn_cfgs, dict):
            ffn_cfgs = [copy.deepcopy(ffn_cfgs) for _ in range(num_ffns)]
        assert len(ffn_cfgs) == num_ffns
        for ffn_index in range(num_ffns):
            if 'embed_dims' not in ffn_cfgs[ffn_index]:
                ffn_cfgs[ffn_index]['embed_dims'] = self.embed_dims
            else:
                assert ffn_cfgs[ffn_index]['embed_dims'] == self.embed_dims
            self.ffns.append(
                build_feedforward_network(ffn_cfgs[ffn_index],
                                          dict(type='FFN')))

        self.norms = ModuleList()
        num_norms = operation_order.count('norm')
        for _ in range(num_norms):
            self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])

    def forward(self,
                query,
                key=None,
                value=None,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerDecoderLayer`.

        **kwargs contains some specific arguments of attentions.

        Args:
            query (Tensor): The input query with shape
                [num_queries, bs, embed_dims] if
                self.batch_first is False, else
                [bs, num_queries, embed_dims].
            key (Tensor): The key tensor with shape [num_keys, bs,
                embed_dims] if self.batch_first is False, else
                [bs, num_keys, embed_dims].
            value (Tensor): The value tensor with same shape as `key`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor] | None): 2D Tensor used in
                calculation of corresponding attention. The length of
                it should equal to the number of `attention` in
                `operation_order`. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in `self_attn` layer.
                Defaults to None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: forwarded results with shape [num_queries, bs, embed_dims].
        """

        norm_index = 0
        attn_index = 0
        ffn_index = 0
        identity = query
        if attn_masks is None:
            attn_masks = [None for _ in range(self.num_attn)]
        elif isinstance(attn_masks, torch.Tensor):
            attn_masks = [
                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
            ]
            warnings.warn(f'Use same attn_mask in all attentions in '
                          f'{self.__class__.__name__} ')
        else:
            assert len(attn_masks) == self.num_attn, f'The length of ' \
                f'attn_masks {len(attn_masks)} must be equal ' \
                f'to the number of attention in ' \
                f'operation_order {self.num_attn}'

        for layer in self.operation_order:
            if layer == 'self_attn':
                temp_key = temp_value = query
                query = self.attentions[attn_index](
                    query,
                    temp_key,
                    temp_value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=query_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=query_key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'norm':
                query = self.norms[norm_index](query)
                norm_index += 1

            elif layer == 'cross_attn':
                query = self.attentions[attn_index](
                    query,
                    key,
                    value,
                    identity if self.pre_norm else None,
                    query_pos=query_pos,
                    key_pos=key_pos,
                    attn_mask=attn_masks[attn_index],
                    key_padding_mask=key_padding_mask,
                    **kwargs)
                attn_index += 1
                identity = query

            elif layer == 'ffn':
                query = self.ffns[ffn_index](
                    query, identity if self.pre_norm else None)
                ffn_index += 1

        return query


@TRANSFORMER_LAYER_SEQUENCE.register_module()
class TransformerLayerSequence(BaseModule):
    """Base class for TransformerEncoder and TransformerDecoder in vision
    transformer.

    As base-class of Encoder and Decoder in vision transformer.
    Supports customization such as specifying different kinds
    of `transformer_layer` in `transformer_coder`.

    Args:
        transformerlayers (list[obj:`mmcv.ConfigDict`] |
            obj:`mmcv.ConfigDict`): Config of transformerlayer
            in TransformerCoder. If it is obj:`mmcv.ConfigDict`,
            it would be repeated `num_layers` times to a
            list[`mmcv.ConfigDict`]. Default: None.
        num_layers (int): The number of `TransformerLayer`. Default: None.
        init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
            Default: None.
    """

    def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
        super(TransformerLayerSequence, self).__init__(init_cfg)
        if isinstance(transformerlayers, dict):
            transformerlayers = [
                copy.deepcopy(transformerlayers) for _ in range(num_layers)
            ]
        else:
            assert isinstance(transformerlayers, list) and \
                len(transformerlayers) == num_layers
        self.num_layers = num_layers
        self.layers = ModuleList()
        for i in range(num_layers):
            self.layers.append(build_transformer_layer(transformerlayers[i]))
        self.embed_dims = self.layers[0].embed_dims
        self.pre_norm = self.layers[0].pre_norm

    def forward(self,
                query,
                key,
                value,
                query_pos=None,
                key_pos=None,
                attn_masks=None,
                query_key_padding_mask=None,
                key_padding_mask=None,
                **kwargs):
        """Forward function for `TransformerCoder`.

        Args:
            query (Tensor): Input query with shape
                `(num_queries, bs, embed_dims)`.
            key (Tensor): The key tensor with shape
                `(num_keys, bs, embed_dims)`.
            value (Tensor): The value tensor with shape
                `(num_keys, bs, embed_dims)`.
            query_pos (Tensor): The positional encoding for `query`.
                Default: None.
            key_pos (Tensor): The positional encoding for `key`.
                Default: None.
            attn_masks (List[Tensor], optional): Each element is 2D Tensor
                which is used in calculation of corresponding attention in
                operation_order. Default: None.
            query_key_padding_mask (Tensor): ByteTensor for `query`, with
                shape [bs, num_queries]. Only used in self-attention.
                Default: None.
            key_padding_mask (Tensor): ByteTensor for `key`, with
                shape [bs, num_keys]. Default: None.

        Returns:
            Tensor: results with shape [num_queries, bs, embed_dims].
        """
        for layer in self.layers:
            query = layer(
                query,
                key,
                value,
                query_pos=query_pos,
                key_pos=key_pos,
                attn_masks=attn_masks,
                query_key_padding_mask=query_key_padding_mask,
                key_padding_mask=key_padding_mask,
                **kwargs)
        return query
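Illustrative usage (not part of the uploaded files): a minimal sketch of how these bricks are assembled from configs, assuming the vendored package is importable as `custom_mmpkg.custom_mmcv`; `layer_cfg` and the tensor shapes below are made up for the example.

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks.transformer import build_transformer_layer

    # an encoder-style layer: self-attention and FFN, each followed by LayerNorm (post-norm)
    layer_cfg = dict(
        type='BaseTransformerLayer',
        attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
        ffn_cfgs=dict(type='FFN', embed_dims=256, feedforward_channels=1024),
        operation_order=('self_attn', 'norm', 'ffn', 'norm'))
    layer = build_transformer_layer(layer_cfg)

    query = torch.rand(100, 2, 256)      # (num_queries, bs, embed_dims), batch_first=False
    query_pos = torch.rand(100, 2, 256)  # positional encoding added to the query
    out = layer(query, query_pos=query_pos)
    assert out.shape == (100, 2, 256)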
custom_mmpkg/custom_mmcv/cnn/bricks/upsample.py
ADDED
@@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from ..utils import xavier_init
from .registry import UPSAMPLE_LAYERS

UPSAMPLE_LAYERS.register_module('nearest', module=nn.Upsample)
UPSAMPLE_LAYERS.register_module('bilinear', module=nn.Upsample)


@UPSAMPLE_LAYERS.register_module(name='pixel_shuffle')
class PixelShufflePack(nn.Module):
    """Pixel Shuffle upsample layer.

    This module packs `F.pixel_shuffle()` and a nn.Conv2d module together to
    achieve a simple upsampling with pixel shuffle.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        scale_factor (int): Upsample ratio.
        upsample_kernel (int): Kernel size of the conv layer to expand the
            channels.
    """

    def __init__(self, in_channels, out_channels, scale_factor,
                 upsample_kernel):
        super(PixelShufflePack, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.scale_factor = scale_factor
        self.upsample_kernel = upsample_kernel
        self.upsample_conv = nn.Conv2d(
            self.in_channels,
            self.out_channels * scale_factor * scale_factor,
            self.upsample_kernel,
            padding=(self.upsample_kernel - 1) // 2)
        self.init_weights()

    def init_weights(self):
        xavier_init(self.upsample_conv, distribution='uniform')

    def forward(self, x):
        x = self.upsample_conv(x)
        x = F.pixel_shuffle(x, self.scale_factor)
        return x


def build_upsample_layer(cfg, *args, **kwargs):
    """Build upsample layer.

    Args:
        cfg (dict): The upsample layer config, which should contain:

            - type (str): Layer type.
            - scale_factor (int): Upsample ratio, which is not applicable to
              deconv.
            - layer args: Args needed to instantiate a upsample layer.
        args (argument list): Arguments passed to the ``__init__``
            method of the corresponding conv layer.
        kwargs (keyword arguments): Keyword arguments passed to the
            ``__init__`` method of the corresponding conv layer.

    Returns:
        nn.Module: Created upsample layer.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise KeyError(
            f'the cfg dict must contain the key "type", but got {cfg}')
    cfg_ = cfg.copy()

    layer_type = cfg_.pop('type')
    if layer_type not in UPSAMPLE_LAYERS:
        raise KeyError(f'Unrecognized upsample type {layer_type}')
    else:
        upsample = UPSAMPLE_LAYERS.get(layer_type)

    if upsample is nn.Upsample:
        cfg_['mode'] = layer_type
    layer = upsample(*args, **kwargs, **cfg_)
    return layer
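Illustrative usage (not part of the uploaded files): a short sketch of the registry-driven construction, assuming the vendored import path; the channel counts are arbitrary.

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks.upsample import build_upsample_layer

    # 'nearest'/'bilinear' resolve to nn.Upsample with the matching mode,
    # 'pixel_shuffle' to PixelShufflePack (conv + F.pixel_shuffle)
    up_nearest = build_upsample_layer(dict(type='nearest', scale_factor=2))
    up_shuffle = build_upsample_layer(
        dict(type='pixel_shuffle', in_channels=16, out_channels=16,
             scale_factor=2, upsample_kernel=3))

    x = torch.rand(1, 16, 8, 8)
    assert up_nearest(x).shape == (1, 16, 16, 16)
    assert up_shuffle(x).shape == (1, 16, 16, 16)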
custom_mmpkg/custom_mmcv/cnn/bricks/wrappers.py
ADDED
@@ -0,0 +1,180 @@
# Copyright (c) OpenMMLab. All rights reserved.
r"""Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/wrappers.py # noqa: E501

Wrap some nn modules to support empty tensor input. Currently, these wrappers
are mainly used in mask heads like fcn_mask_head and maskiou_heads since mask
heads are trained on only positive RoIs.
"""
import math

import torch
import torch.nn as nn
from torch.nn.modules.utils import _pair, _triple

from .registry import CONV_LAYERS, UPSAMPLE_LAYERS

if torch.__version__ == 'parrots':
    TORCH_VERSION = torch.__version__
else:
    # torch.__version__ could be 1.3.1+cu92, we only need the first two
    # for comparison
    TORCH_VERSION = tuple(int(x) for x in torch.__version__.split('.')[:2])


def obsolete_torch_version(torch_version, version_threshold):
    return torch_version == 'parrots' or torch_version <= version_threshold


class NewEmptyTensorOp(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, new_shape):
        ctx.shape = x.shape
        return x.new_empty(new_shape)

    @staticmethod
    def backward(ctx, grad):
        shape = ctx.shape
        return NewEmptyTensorOp.apply(grad, shape), None


@CONV_LAYERS.register_module('Conv', force=True)
class Conv2d(nn.Conv2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-2:], self.kernel_size,
                                     self.padding, self.stride, self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module('Conv3d', force=True)
class Conv3d(nn.Conv3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d in zip(x.shape[-3:], self.kernel_size,
                                     self.padding, self.stride, self.dilation):
                o = (i + 2 * p - (d * (k - 1) + 1)) // s + 1
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv')
@UPSAMPLE_LAYERS.register_module('deconv', force=True)
class ConvTranspose2d(nn.ConvTranspose2d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-2:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


@CONV_LAYERS.register_module()
@CONV_LAYERS.register_module('deconv3d')
@UPSAMPLE_LAYERS.register_module('deconv3d', force=True)
class ConvTranspose3d(nn.ConvTranspose3d):

    def forward(self, x):
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 4)):
            out_shape = [x.shape[0], self.out_channels]
            for i, k, p, s, d, op in zip(x.shape[-3:], self.kernel_size,
                                         self.padding, self.stride,
                                         self.dilation, self.output_padding):
                out_shape.append((i - 1) * s - 2 * p + (d * (k - 1) + 1) + op)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)


class MaxPool2d(nn.MaxPool2d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-2:], _pair(self.kernel_size),
                                     _pair(self.padding), _pair(self.stride),
                                     _pair(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)


class MaxPool3d(nn.MaxPool3d):

    def forward(self, x):
        # PyTorch 1.9 does not support empty tensor inference yet
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 9)):
            out_shape = list(x.shape[:2])
            for i, k, p, s, d in zip(x.shape[-3:], _triple(self.kernel_size),
                                     _triple(self.padding),
                                     _triple(self.stride),
                                     _triple(self.dilation)):
                o = (i + 2 * p - (d * (k - 1) + 1)) / s + 1
                o = math.ceil(o) if self.ceil_mode else math.floor(o)
                out_shape.append(o)
            empty = NewEmptyTensorOp.apply(x, out_shape)
            return empty

        return super().forward(x)


class Linear(torch.nn.Linear):

    def forward(self, x):
        # empty tensor forward of Linear layer is supported in Pytorch 1.6
        if x.numel() == 0 and obsolete_torch_version(TORCH_VERSION, (1, 5)):
            out_shape = [x.shape[0], self.out_features]
            empty = NewEmptyTensorOp.apply(x, out_shape)
            if self.training:
                # produce dummy gradient to avoid DDP warning.
                dummy = sum(x.view(-1)[0] for x in self.parameters()) * 0.0
                return empty + dummy
            else:
                return empty

        return super().forward(x)
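Illustrative usage (not part of the uploaded files): a sketch of the empty-tensor behaviour these wrappers provide, assuming the vendored import path. On recent PyTorch the native ops already accept zero-sized batches, so the wrappers simply fall through to `super().forward`.

    import torch
    from custom_mmpkg.custom_mmcv.cnn.bricks.wrappers import Conv2d, Linear

    conv = Conv2d(3, 8, kernel_size=3, padding=1)
    fc = Linear(16, 4)

    # zero positive RoIs: the batch dimension is empty, but output shapes
    # are still propagated so downstream code needs no special cases
    feat = conv(torch.rand(0, 3, 7, 7))   # -> torch.Size([0, 8, 7, 7])
    vec = fc(torch.rand(0, 16))           # -> torch.Size([0, 4])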
custom_mmpkg/custom_mmcv/cnn/builder.py
ADDED
@@ -0,0 +1,30 @@
# Copyright (c) OpenMMLab. All rights reserved.
from ..runner import Sequential
from ..utils import Registry, build_from_cfg


def build_model_from_cfg(cfg, registry, default_args=None):
    """Build a PyTorch model from config dict(s). Different from
    ``build_from_cfg``, if cfg is a list, a ``nn.Sequential`` will be built.

    Args:
        cfg (dict, list[dict]): The config of modules; it is either a config
            dict or a list of config dicts. If cfg is a list,
            the built modules will be wrapped with ``nn.Sequential``.
        registry (:obj:`Registry`): A registry the module belongs to.
        default_args (dict, optional): Default arguments to build the module.
            Defaults to None.

    Returns:
        nn.Module: A built nn module.
    """
    if isinstance(cfg, list):
        modules = [
            build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
        ]
        return Sequential(*modules)
    else:
        return build_from_cfg(cfg, registry, default_args)


MODELS = Registry('model', build_func=build_model_from_cfg)
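Illustrative usage (not part of the uploaded files): a sketch of the list-to-``Sequential`` behaviour, assuming the vendored import path; ``TinyBlock`` is a hypothetical module registered only for this example.

    import torch.nn as nn
    from custom_mmpkg.custom_mmcv.cnn.builder import MODELS, build_model_from_cfg


    @MODELS.register_module()
    class TinyBlock(nn.Module):  # hypothetical module, for illustration only

        def __init__(self, channels):
            super().__init__()
            self.conv = nn.Conv2d(channels, channels, 3, padding=1)

        def forward(self, x):
            return self.conv(x)


    # a single dict builds one module; a list of dicts is wrapped in Sequential
    block = build_model_from_cfg(dict(type='TinyBlock', channels=8), MODELS)
    stack = build_model_from_cfg(
        [dict(type='TinyBlock', channels=8), dict(type='TinyBlock', channels=8)],
        MODELS)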