Spaces:
Running
Running
Upload 75 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- .gradio/certificate.pem +31 -0
- README_Unified.md +128 -0
- app.py +299 -0
- app2.py +324 -0
- checkpoints/labels.txt +4 -0
- depth_anything_v2/__pycache__/dinov2.cpython-312.pyc +0 -0
- depth_anything_v2/__pycache__/dpt.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2.py +415 -0
- depth_anything_v2/dinov2_layers/__init__.py +11 -0
- depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/block.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-312.pyc +0 -0
- depth_anything_v2/dinov2_layers/attention.py +83 -0
- depth_anything_v2/dinov2_layers/block.py +252 -0
- depth_anything_v2/dinov2_layers/drop_path.py +35 -0
- depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
- depth_anything_v2/dinov2_layers/mlp.py +41 -0
- depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
- depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
- depth_anything_v2/dpt.py +221 -0
- depth_anything_v2/util/__pycache__/blocks.cpython-312.pyc +0 -0
- depth_anything_v2/util/__pycache__/transform.cpython-312.pyc +0 -0
- depth_anything_v2/util/blocks.py +148 -0
- depth_anything_v2/util/transform.py +158 -0
- environment.yml +0 -0
- environment_export.yml +182 -0
- environment_from_history.yml +30 -0
- environment_linux.yml +116 -0
- keras_model_3.h5 +3 -0
- labels.txt +4 -0
- main_app.py +540 -0
- metric_depth/README.md +114 -0
- metric_depth/assets/compare_zoedepth.png +3 -0
- metric_depth/dataset/hypersim.py +74 -0
- metric_depth/dataset/kitti.py +57 -0
- metric_depth/dataset/splits/hypersim/val.txt +0 -0
- metric_depth/dataset/splits/kitti/val.txt +0 -0
- metric_depth/dataset/splits/vkitti2/train.txt +0 -0
- metric_depth/dataset/transform.py +277 -0
- metric_depth/dataset/vkitti2.py +54 -0
- metric_depth/depth_anything_v2/dinov2.py +415 -0
- metric_depth/depth_anything_v2/dinov2_layers/__init__.py +11 -0
- metric_depth/depth_anything_v2/dinov2_layers/attention.py +83 -0
- metric_depth/depth_anything_v2/dinov2_layers/block.py +252 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
metric_depth/assets/compare_zoedepth.png filter=lfs diff=lfs merge=lfs -text
|
.gradio/certificate.pem
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
-----BEGIN CERTIFICATE-----
|
2 |
+
MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
|
3 |
+
TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
|
4 |
+
cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
|
5 |
+
WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
|
6 |
+
ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
|
7 |
+
MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
|
8 |
+
h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
|
9 |
+
0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
|
10 |
+
A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
|
11 |
+
T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
|
12 |
+
B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
|
13 |
+
B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
|
14 |
+
KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
|
15 |
+
OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
|
16 |
+
jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
|
17 |
+
qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
|
18 |
+
rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
|
19 |
+
HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
|
20 |
+
hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
|
21 |
+
ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
|
22 |
+
3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
|
23 |
+
NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
|
24 |
+
ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
|
25 |
+
TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
|
26 |
+
jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
|
27 |
+
oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
|
28 |
+
4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
|
29 |
+
mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
|
30 |
+
emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
|
31 |
+
-----END CERTIFICATE-----
|
README_Unified.md
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Medical AI Suite - Unified Interface
|
2 |
+
|
3 |
+
A comprehensive web application that combines wound classification and depth estimation capabilities in a single, modern interface.
|
4 |
+
|
5 |
+
## 🚀 Quick Start
|
6 |
+
|
7 |
+
### Option 1: Use the Launcher (Recommended)
|
8 |
+
```bash
|
9 |
+
python launcher.py
|
10 |
+
```
|
11 |
+
This will show you a menu to choose which application to run.
|
12 |
+
|
13 |
+
### Option 2: Run the Unified Interface Directly
|
14 |
+
```bash
|
15 |
+
python main_app.py
|
16 |
+
```
|
17 |
+
|
18 |
+
### Option 3: Run Individual Applications
|
19 |
+
```bash
|
20 |
+
# Wound Classification only
|
21 |
+
python app2.py
|
22 |
+
|
23 |
+
# Depth Estimation only
|
24 |
+
python app.py
|
25 |
+
```
|
26 |
+
|
27 |
+
## 🏥 Features
|
28 |
+
|
29 |
+
### Tab 1: Wound Classification
|
30 |
+
- **AI-powered wound type classification**
|
31 |
+
- **Grad-CAM visualization** - See which areas the model focuses on
|
32 |
+
- **Confidence scores** with color-coded bars
|
33 |
+
- **Real-time analysis** - Results update as you upload images
|
34 |
+
|
35 |
+
### Tab 2: Depth Estimation & 3D Visualization
|
36 |
+
- **Depth map generation** using DepthAnythingV2 model
|
37 |
+
- **Interactive 3D point cloud visualization**
|
38 |
+
- **Adjustable parameters** (focal length, point density)
|
39 |
+
- **Multiple output formats** (grayscale, raw, PLY point cloud)
|
40 |
+
- **Image slider comparison** between original and depth map
|
41 |
+
|
42 |
+
## 🎨 Interface Features
|
43 |
+
|
44 |
+
- **Modern dark theme** with gradient backgrounds
|
45 |
+
- **Tabbed navigation** between applications
|
46 |
+
- **Responsive design** that works on different screen sizes
|
47 |
+
- **Professional medical interface** styling
|
48 |
+
- **Real-time feedback** and progress indicators
|
49 |
+
|
50 |
+
## 📁 File Structure
|
51 |
+
|
52 |
+
```
|
53 |
+
├── main_app.py # Unified interface (NEW)
|
54 |
+
├── launcher.py # Application launcher (NEW)
|
55 |
+
├── app.py # Original depth estimation app
|
56 |
+
├── app2.py # Original wound classification app
|
57 |
+
├── checkpoints/
|
58 |
+
│ ├── keras_model.h5 # Wound classification model
|
59 |
+
│ └── depth_anything_v2_vitl.pth # Depth estimation model
|
60 |
+
├── labels.txt # Wound classification labels
|
61 |
+
└── depth_anything_v2/ # Depth model implementation
|
62 |
+
```
|
63 |
+
|
64 |
+
## 🔧 Requirements
|
65 |
+
|
66 |
+
The unified interface requires all the same dependencies as the individual applications:
|
67 |
+
|
68 |
+
- `gradio`
|
69 |
+
- `tensorflow`
|
70 |
+
- `torch`
|
71 |
+
- `opencv-python`
|
72 |
+
- `pillow`
|
73 |
+
- `numpy`
|
74 |
+
- `matplotlib`
|
75 |
+
- `plotly`
|
76 |
+
- `open3d`
|
77 |
+
- `gradio-imageslider`
|
78 |
+
|
79 |
+
## 🌐 Access
|
80 |
+
|
81 |
+
Once launched, the interface will be available at:
|
82 |
+
- **Local**: http://localhost:7860
|
83 |
+
- **Public**: A public link will be provided when the server starts
|
84 |
+
|
85 |
+
## 💡 Usage Tips
|
86 |
+
|
87 |
+
### Wound Classification
|
88 |
+
1. Upload a clear image of the wound
|
89 |
+
2. The model will automatically classify the wound type
|
90 |
+
3. View the Grad-CAM heatmap to see which areas influenced the decision
|
91 |
+
4. Check confidence scores for all possible classifications
|
92 |
+
|
93 |
+
### Depth Estimation
|
94 |
+
1. Upload an image for depth analysis
|
95 |
+
2. Adjust the number of 3D points (higher = more detailed but slower)
|
96 |
+
3. Set focal length parameters if you know your camera specs
|
97 |
+
4. Click "Compute Depth" to generate results
|
98 |
+
5. Download depth maps and point clouds as needed
|
99 |
+
6. Explore the interactive 3D visualization
|
100 |
+
|
101 |
+
## 🛠️ Troubleshooting
|
102 |
+
|
103 |
+
### Model Loading Issues
|
104 |
+
If models fail to load, the interface will show appropriate error messages and continue to function with limited capabilities.
|
105 |
+
|
106 |
+
### Performance
|
107 |
+
- For large images, consider reducing the number of 3D points
|
108 |
+
- Depth estimation works best with good lighting and clear subjects
|
109 |
+
- Wound classification works best with well-lit, focused images
|
110 |
+
|
111 |
+
### Browser Compatibility
|
112 |
+
The interface works best with modern browsers (Chrome, Firefox, Safari, Edge).
|
113 |
+
|
114 |
+
## 🔄 Navigation
|
115 |
+
|
116 |
+
You can easily switch between the two main functionalities using the tabs at the top of the interface. Each tab maintains its own state, so you can work on both applications simultaneously.
|
117 |
+
|
118 |
+
## 📞 Support
|
119 |
+
|
120 |
+
If you encounter any issues:
|
121 |
+
1. Check that all required model files are present
|
122 |
+
2. Ensure all dependencies are installed
|
123 |
+
3. Try running individual applications first to isolate issues
|
124 |
+
4. Check the console output for error messages
|
125 |
+
|
126 |
+
---
|
127 |
+
|
128 |
+
**Enjoy using the Medical AI Suite! 🏥✨**
|
app.py
ADDED
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import glob
|
2 |
+
import gradio as gr
|
3 |
+
import matplotlib
|
4 |
+
import numpy as np
|
5 |
+
from PIL import Image
|
6 |
+
import torch
|
7 |
+
import tempfile
|
8 |
+
from gradio_imageslider import ImageSlider
|
9 |
+
import plotly.graph_objects as go
|
10 |
+
import plotly.express as px
|
11 |
+
import open3d as o3d
|
12 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
13 |
+
|
14 |
+
css = """
|
15 |
+
#img-display-container {
|
16 |
+
max-height: 100vh;
|
17 |
+
}
|
18 |
+
#img-display-input {
|
19 |
+
max-height: 80vh;
|
20 |
+
}
|
21 |
+
#img-display-output {
|
22 |
+
max-height: 80vh;
|
23 |
+
}
|
24 |
+
#download {
|
25 |
+
height: 62px;
|
26 |
+
}
|
27 |
+
h1 {
|
28 |
+
text-align: center;
|
29 |
+
font-size: 3rem;
|
30 |
+
font-weight: bold;
|
31 |
+
margin: 2rem 0;
|
32 |
+
color: #2c3e50;
|
33 |
+
}
|
34 |
+
"""
|
35 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
36 |
+
model_configs = {
|
37 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
38 |
+
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
|
39 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
40 |
+
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
|
41 |
+
}
|
42 |
+
encoder = 'vitl'
|
43 |
+
model = DepthAnythingV2(**model_configs[encoder])
|
44 |
+
state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
|
45 |
+
model.load_state_dict(state_dict)
|
46 |
+
model = model.to(DEVICE).eval()
|
47 |
+
|
48 |
+
title = "Depth Estimation, 3D Visualization"
|
49 |
+
description = """Official demo for **Depth Estimation, 3D Visualization**."""
|
50 |
+
|
51 |
+
def predict_depth(image):
|
52 |
+
return model.infer_image(image)
|
53 |
+
|
54 |
+
def calculate_max_points(image):
|
55 |
+
"""Calculate maximum points based on image dimensions (3x pixel count)"""
|
56 |
+
if image is None:
|
57 |
+
return 10000 # Default value
|
58 |
+
h, w = image.shape[:2]
|
59 |
+
max_points = h * w * 3
|
60 |
+
# Ensure minimum and reasonable maximum values
|
61 |
+
return max(1000, min(max_points, 1000000))
|
62 |
+
|
63 |
+
def update_slider_on_image_upload(image):
|
64 |
+
"""Update the points slider when an image is uploaded"""
|
65 |
+
max_points = calculate_max_points(image)
|
66 |
+
default_value = min(10000, max_points // 10) # 10% of max points as default
|
67 |
+
return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
|
68 |
+
label=f"Number of 3D points (max: {max_points:,})")
|
69 |
+
|
70 |
+
def create_3d_depth_visualization(image, depth_map, max_points=10000):
|
71 |
+
"""Create an interactive 3D visualization of the depth map"""
|
72 |
+
h, w = depth_map.shape
|
73 |
+
|
74 |
+
# Downsample to avoid too many points for performance
|
75 |
+
step = max(1, int(np.sqrt(h * w / max_points)))
|
76 |
+
|
77 |
+
# Create coordinate grids
|
78 |
+
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
|
79 |
+
depth_values = depth_map[::step, ::step]
|
80 |
+
|
81 |
+
# Flatten arrays
|
82 |
+
x_flat = x_coords.flatten()
|
83 |
+
y_flat = y_coords.flatten()
|
84 |
+
z_flat = depth_values.flatten()
|
85 |
+
|
86 |
+
# Get corresponding image colors
|
87 |
+
image_colors = image[::step, ::step, :]
|
88 |
+
colors_flat = image_colors.reshape(-1, 3)
|
89 |
+
|
90 |
+
# Create 3D scatter plot
|
91 |
+
fig = go.Figure(data=[go.Scatter3d(
|
92 |
+
x=x_flat,
|
93 |
+
y=y_flat,
|
94 |
+
z=z_flat,
|
95 |
+
mode='markers',
|
96 |
+
marker=dict(
|
97 |
+
size=2,
|
98 |
+
color=colors_flat,
|
99 |
+
opacity=0.8
|
100 |
+
),
|
101 |
+
hovertemplate='<b>Position:</b> (%{x:.0f}, %{y:.0f})<br>' +
|
102 |
+
'<b>Depth:</b> %{z:.2f}<br>' +
|
103 |
+
'<extra></extra>'
|
104 |
+
)])
|
105 |
+
|
106 |
+
fig.update_layout(
|
107 |
+
title="3D Depth Visualization (Hover to see depth values)",
|
108 |
+
scene=dict(
|
109 |
+
xaxis_title="X (pixels)",
|
110 |
+
yaxis_title="Y (pixels)",
|
111 |
+
zaxis_title="Depth",
|
112 |
+
camera=dict(
|
113 |
+
eye=dict(x=1.5, y=1.5, z=1.5)
|
114 |
+
)
|
115 |
+
),
|
116 |
+
width=600,
|
117 |
+
height=500
|
118 |
+
)
|
119 |
+
|
120 |
+
return fig
|
121 |
+
|
122 |
+
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
|
123 |
+
"""Create a point cloud from depth map using camera intrinsics"""
|
124 |
+
h, w = depth_map.shape
|
125 |
+
|
126 |
+
# Downsample to avoid too many points for performance
|
127 |
+
step = max(1, int(np.sqrt(h * w / max_points)))
|
128 |
+
|
129 |
+
# Create mesh grid for camera coordinates
|
130 |
+
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
|
131 |
+
|
132 |
+
# Convert to camera coordinates (normalized by focal length)
|
133 |
+
x_cam = (x_coords - w / 2) / focal_length_x
|
134 |
+
y_cam = (y_coords - h / 2) / focal_length_y
|
135 |
+
|
136 |
+
# Get depth values
|
137 |
+
depth_values = depth_map[::step, ::step]
|
138 |
+
|
139 |
+
# Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
|
140 |
+
x_3d = x_cam * depth_values
|
141 |
+
y_3d = y_cam * depth_values
|
142 |
+
z_3d = depth_values
|
143 |
+
|
144 |
+
# Flatten arrays
|
145 |
+
points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
|
146 |
+
|
147 |
+
# Get corresponding image colors
|
148 |
+
image_colors = image[::step, ::step, :]
|
149 |
+
colors = image_colors.reshape(-1, 3) / 255.0
|
150 |
+
|
151 |
+
# Create Open3D point cloud
|
152 |
+
pcd = o3d.geometry.PointCloud()
|
153 |
+
pcd.points = o3d.utility.Vector3dVector(points)
|
154 |
+
pcd.colors = o3d.utility.Vector3dVector(colors)
|
155 |
+
|
156 |
+
return pcd
|
157 |
+
|
158 |
+
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
|
159 |
+
"""Create an enhanced 3D visualization using proper camera projection"""
|
160 |
+
h, w = depth_map.shape
|
161 |
+
|
162 |
+
# Downsample to avoid too many points for performance
|
163 |
+
step = max(1, int(np.sqrt(h * w / max_points)))
|
164 |
+
|
165 |
+
# Create mesh grid for camera coordinates
|
166 |
+
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
|
167 |
+
|
168 |
+
# Convert to camera coordinates (normalized by focal length)
|
169 |
+
focal_length = 470.4 # Default focal length
|
170 |
+
x_cam = (x_coords - w / 2) / focal_length
|
171 |
+
y_cam = (y_coords - h / 2) / focal_length
|
172 |
+
|
173 |
+
# Get depth values
|
174 |
+
depth_values = depth_map[::step, ::step]
|
175 |
+
|
176 |
+
# Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
|
177 |
+
x_3d = x_cam * depth_values
|
178 |
+
y_3d = y_cam * depth_values
|
179 |
+
z_3d = depth_values
|
180 |
+
|
181 |
+
# Flatten arrays
|
182 |
+
x_flat = x_3d.flatten()
|
183 |
+
y_flat = y_3d.flatten()
|
184 |
+
z_flat = z_3d.flatten()
|
185 |
+
|
186 |
+
# Get corresponding image colors
|
187 |
+
image_colors = image[::step, ::step, :]
|
188 |
+
colors_flat = image_colors.reshape(-1, 3)
|
189 |
+
|
190 |
+
# Create 3D scatter plot with proper camera projection
|
191 |
+
fig = go.Figure(data=[go.Scatter3d(
|
192 |
+
x=x_flat,
|
193 |
+
y=y_flat,
|
194 |
+
z=z_flat,
|
195 |
+
mode='markers',
|
196 |
+
marker=dict(
|
197 |
+
size=1.5,
|
198 |
+
color=colors_flat,
|
199 |
+
opacity=0.9
|
200 |
+
),
|
201 |
+
hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
|
202 |
+
'<b>Depth:</b> %{z:.2f}<br>' +
|
203 |
+
'<extra></extra>'
|
204 |
+
)])
|
205 |
+
|
206 |
+
fig.update_layout(
|
207 |
+
title="3D Point Cloud Visualization (Camera Projection)",
|
208 |
+
scene=dict(
|
209 |
+
xaxis_title="X (meters)",
|
210 |
+
yaxis_title="Y (meters)",
|
211 |
+
zaxis_title="Z (meters)",
|
212 |
+
camera=dict(
|
213 |
+
eye=dict(x=2.0, y=2.0, z=2.0),
|
214 |
+
center=dict(x=0, y=0, z=0),
|
215 |
+
up=dict(x=0, y=0, z=1)
|
216 |
+
),
|
217 |
+
aspectmode='data'
|
218 |
+
),
|
219 |
+
width=700,
|
220 |
+
height=600
|
221 |
+
)
|
222 |
+
|
223 |
+
return fig
|
224 |
+
|
225 |
+
with gr.Blocks(css=css) as demo:
|
226 |
+
gr.HTML(f"<h1>{title}</h1>")
|
227 |
+
gr.Markdown(description)
|
228 |
+
gr.Markdown("### Depth Prediction demo")
|
229 |
+
|
230 |
+
with gr.Row():
|
231 |
+
input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
232 |
+
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
|
233 |
+
|
234 |
+
with gr.Row():
|
235 |
+
submit = gr.Button(value="Compute Depth", variant="primary")
|
236 |
+
points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
|
237 |
+
label="Number of 3D points (upload image to update max)")
|
238 |
+
|
239 |
+
with gr.Row():
|
240 |
+
focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
241 |
+
label="Focal Length X (pixels)")
|
242 |
+
focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
243 |
+
label="Focal Length Y (pixels)")
|
244 |
+
|
245 |
+
with gr.Row():
|
246 |
+
gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
|
247 |
+
raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
|
248 |
+
point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
|
249 |
+
|
250 |
+
# 3D Visualization
|
251 |
+
gr.Markdown("### 3D Point Cloud Visualization")
|
252 |
+
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
|
253 |
+
depth_3d_plot = gr.Plot(label="3D Point Cloud")
|
254 |
+
|
255 |
+
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
|
256 |
+
|
257 |
+
def on_submit(image, num_points, focal_x, focal_y):
|
258 |
+
original_image = image.copy()
|
259 |
+
|
260 |
+
h, w = image.shape[:2]
|
261 |
+
|
262 |
+
depth = predict_depth(image[:, :, ::-1])
|
263 |
+
|
264 |
+
raw_depth = Image.fromarray(depth.astype('uint16'))
|
265 |
+
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
266 |
+
raw_depth.save(tmp_raw_depth.name)
|
267 |
+
|
268 |
+
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
269 |
+
depth = depth.astype(np.uint8)
|
270 |
+
colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
|
271 |
+
|
272 |
+
gray_depth = Image.fromarray(depth)
|
273 |
+
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
274 |
+
gray_depth.save(tmp_gray_depth.name)
|
275 |
+
|
276 |
+
# Create point cloud
|
277 |
+
pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
|
278 |
+
tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
|
279 |
+
o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
|
280 |
+
|
281 |
+
# Create enhanced 3D visualization
|
282 |
+
depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
|
283 |
+
|
284 |
+
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
285 |
+
|
286 |
+
# Update slider when image is uploaded
|
287 |
+
input_image.change(
|
288 |
+
fn=update_slider_on_image_upload,
|
289 |
+
inputs=[input_image],
|
290 |
+
outputs=[points_slider]
|
291 |
+
)
|
292 |
+
|
293 |
+
submit.click(on_submit, inputs=[input_image, points_slider, focal_length_x, focal_length_y],
|
294 |
+
outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot])
|
295 |
+
|
296 |
+
|
297 |
+
|
298 |
+
if __name__ == '__main__':
|
299 |
+
demo.queue().launch()
|
app2.py
ADDED
@@ -0,0 +1,324 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.models import load_model
|
5 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
6 |
+
from tensorflow.keras import backend as K
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from PIL import Image
|
9 |
+
import io
|
10 |
+
import cv2
|
11 |
+
|
12 |
+
# --- Load model and labels ---
|
13 |
+
model = load_model("checkpoints/keras_model.h5")
|
14 |
+
with open("labels.txt", "r") as f:
|
15 |
+
class_labels = [line.strip() for line in f]
|
16 |
+
|
17 |
+
# --- Preprocess input ---
|
18 |
+
def preprocess_input(img):
|
19 |
+
img = img.resize((224, 224))
|
20 |
+
arr = keras_image.img_to_array(img)
|
21 |
+
arr = arr / 255.0
|
22 |
+
return np.expand_dims(arr, axis=0)
|
23 |
+
|
24 |
+
# --- Enhanced Grad-CAM implementation for Keras ---
|
25 |
+
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
|
26 |
+
try:
|
27 |
+
# Try to find the specified layer
|
28 |
+
target_layer = model.get_layer(last_conv_layer_name)
|
29 |
+
except:
|
30 |
+
# Fallback: find any convolutional layer
|
31 |
+
for layer in model.layers:
|
32 |
+
if 'conv' in layer.name.lower():
|
33 |
+
target_layer = layer
|
34 |
+
break
|
35 |
+
else:
|
36 |
+
return None
|
37 |
+
|
38 |
+
grad_model = tf.keras.models.Model(
|
39 |
+
[model.inputs], [target_layer.output, model.output]
|
40 |
+
)
|
41 |
+
|
42 |
+
with tf.GradientTape() as tape:
|
43 |
+
conv_outputs, predictions = grad_model(img_array)
|
44 |
+
loss = predictions[:, class_index]
|
45 |
+
|
46 |
+
grads = tape.gradient(loss, conv_outputs)[0]
|
47 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
48 |
+
conv_outputs = conv_outputs[0]
|
49 |
+
|
50 |
+
heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
|
51 |
+
heatmap = np.maximum(heatmap, 0)
|
52 |
+
heatmap = heatmap / np.max(heatmap + K.epsilon())
|
53 |
+
return heatmap.numpy()
|
54 |
+
|
55 |
+
# --- Enhanced Overlay heatmap on image ---
|
56 |
+
def overlay_gradcam(original_img, heatmap):
|
57 |
+
if heatmap is None:
|
58 |
+
return original_img
|
59 |
+
|
60 |
+
# Resize heatmap
|
61 |
+
heatmap = cv2.resize(heatmap, original_img.size)
|
62 |
+
|
63 |
+
# Normalize safely
|
64 |
+
heatmap = np.maximum(heatmap, 0)
|
65 |
+
if np.max(heatmap) != 0:
|
66 |
+
heatmap /= np.max(heatmap)
|
67 |
+
heatmap = np.uint8(255 * heatmap)
|
68 |
+
|
69 |
+
# Apply JET colormap for better medical visualization
|
70 |
+
heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
71 |
+
|
72 |
+
# Convert PIL to array
|
73 |
+
original_array = np.array(original_img.convert("RGB"))
|
74 |
+
|
75 |
+
# Enhanced blend with better contrast
|
76 |
+
superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
|
77 |
+
|
78 |
+
return Image.fromarray(superimposed_img)
|
79 |
+
|
80 |
+
# --- Enhanced Prediction Function ---
|
81 |
+
def classify_and_explain(img):
|
82 |
+
if img is None:
|
83 |
+
return None, {}, "No image provided"
|
84 |
+
|
85 |
+
img_array = preprocess_input(img)
|
86 |
+
predictions = model.predict(img_array, verbose=0)[0]
|
87 |
+
pred_idx = int(np.argmax(predictions))
|
88 |
+
pred_class = class_labels[pred_idx]
|
89 |
+
confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
|
90 |
+
|
91 |
+
# Enhanced Grad-CAM
|
92 |
+
try:
|
93 |
+
heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
|
94 |
+
gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
|
95 |
+
except Exception as e:
|
96 |
+
print(f"Grad-CAM error: {e}")
|
97 |
+
gradcam_img = img.resize((224, 224)) # fallback image
|
98 |
+
|
99 |
+
return gradcam_img, confidence_dict
|
100 |
+
|
101 |
+
# --- Custom CSS for Dark Mode Medical Interface ---
|
102 |
+
css = """
|
103 |
+
.gradio-container {
|
104 |
+
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
|
105 |
+
background: #1a1a1a;
|
106 |
+
min-height: 100vh;
|
107 |
+
padding: 20px;
|
108 |
+
color: #ffffff;
|
109 |
+
}
|
110 |
+
|
111 |
+
.main-header {
|
112 |
+
text-align: center;
|
113 |
+
color: white;
|
114 |
+
margin-bottom: 2rem;
|
115 |
+
padding: 2rem 0;
|
116 |
+
}
|
117 |
+
|
118 |
+
.main-header h1 {
|
119 |
+
font-size: 2.5rem;
|
120 |
+
margin-bottom: 0.5rem;
|
121 |
+
text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
|
122 |
+
color: #ffffff;
|
123 |
+
}
|
124 |
+
|
125 |
+
.confidence-bar {
|
126 |
+
background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
|
127 |
+
height: 25px;
|
128 |
+
border-radius: 12px;
|
129 |
+
margin: 8px 0;
|
130 |
+
transition: all 0.3s ease;
|
131 |
+
box-shadow: 0 2px 4px rgba(0,0,0,0.3);
|
132 |
+
}
|
133 |
+
|
134 |
+
.confidence-container {
|
135 |
+
margin: 15px 0;
|
136 |
+
padding: 20px;
|
137 |
+
border-radius: 12px;
|
138 |
+
background: rgba(255,255,255,0.1);
|
139 |
+
backdrop-filter: blur(10px);
|
140 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
|
141 |
+
border: 1px solid rgba(255,255,255,0.1);
|
142 |
+
}
|
143 |
+
|
144 |
+
.input-section, .output-section {
|
145 |
+
background: rgba(255,255,255,0.05);
|
146 |
+
padding: 25px;
|
147 |
+
border-radius: 15px;
|
148 |
+
margin: 15px;
|
149 |
+
backdrop-filter: blur(10px);
|
150 |
+
box-shadow: 0 8px 32px rgba(0,0,0,0.3);
|
151 |
+
border: 1px solid rgba(255,255,255,0.1);
|
152 |
+
}
|
153 |
+
|
154 |
+
.section-title {
|
155 |
+
color: #ffffff;
|
156 |
+
font-size: 1.3rem;
|
157 |
+
font-weight: 600;
|
158 |
+
margin-bottom: 15px;
|
159 |
+
border-bottom: 2px solid #3498db;
|
160 |
+
padding-bottom: 8px;
|
161 |
+
}
|
162 |
+
|
163 |
+
.gradio-button {
|
164 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
165 |
+
border: none;
|
166 |
+
color: white;
|
167 |
+
padding: 12px 24px;
|
168 |
+
border-radius: 25px;
|
169 |
+
font-weight: 600;
|
170 |
+
transition: all 0.3s ease;
|
171 |
+
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
|
172 |
+
}
|
173 |
+
|
174 |
+
.gradio-button:hover {
|
175 |
+
transform: translateY(-2px);
|
176 |
+
box-shadow: 0 6px 20px rgba(0,0,0,0.4);
|
177 |
+
}
|
178 |
+
|
179 |
+
.gradio-image {
|
180 |
+
border-radius: 12px;
|
181 |
+
box-shadow: 0 4px 15px rgba(0,0,0,0.3);
|
182 |
+
border: 1px solid rgba(255,255,255,0.1);
|
183 |
+
}
|
184 |
+
|
185 |
+
.gradio-textbox, .gradio-number {
|
186 |
+
border-radius: 8px;
|
187 |
+
border: 2px solid #333333;
|
188 |
+
padding: 12px;
|
189 |
+
font-size: 1rem;
|
190 |
+
background: rgba(255,255,255,0.05);
|
191 |
+
color: #ffffff;
|
192 |
+
}
|
193 |
+
|
194 |
+
.gradio-textbox:focus, .gradio-number:focus {
|
195 |
+
border-color: #3498db;
|
196 |
+
box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
|
197 |
+
}
|
198 |
+
|
199 |
+
.gradio-label {
|
200 |
+
color: #ffffff !important;
|
201 |
+
}
|
202 |
+
|
203 |
+
.heatmap-container {
|
204 |
+
background: rgba(255,255,255,0.05);
|
205 |
+
padding: 15px;
|
206 |
+
border-radius: 12px;
|
207 |
+
border: 1px solid rgba(255,255,255,0.1);
|
208 |
+
margin: 10px 0;
|
209 |
+
}
|
210 |
+
|
211 |
+
.prediction-container {
|
212 |
+
background: rgba(52,152,219,0.1);
|
213 |
+
padding: 20px;
|
214 |
+
border-radius: 12px;
|
215 |
+
border-left: 5px solid #3498db;
|
216 |
+
margin: 15px 0;
|
217 |
+
}
|
218 |
+
"""
|
219 |
+
|
220 |
+
# --- Function to create confidence bars HTML ---
|
221 |
+
def create_confidence_bars(confidence_dict):
|
222 |
+
html_content = "<div class='confidence-container'>"
|
223 |
+
for class_name, confidence in confidence_dict.items():
|
224 |
+
percentage = confidence * 100
|
225 |
+
# Color coding based on confidence
|
226 |
+
if percentage > 70:
|
227 |
+
color = "#28a745" # Green for high confidence
|
228 |
+
elif percentage > 40:
|
229 |
+
color = "#ffc107" # Yellow for medium confidence
|
230 |
+
else:
|
231 |
+
color = "#dc3545" # Red for low confidence
|
232 |
+
|
233 |
+
html_content += f"""
|
234 |
+
<div style='margin: 12px 0;'>
|
235 |
+
<div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
|
236 |
+
<span style='font-weight: bold; color: {color};'>{class_name}</span>
|
237 |
+
<span style='font-weight: bold; color: {color};'>{percentage:.1f}%</span>
|
238 |
+
</div>
|
239 |
+
<div class='confidence-bar' style='width: {percentage}%; background: {color};'></div>
|
240 |
+
</div>
|
241 |
+
"""
|
242 |
+
html_content += "</div>"
|
243 |
+
return html_content
|
244 |
+
|
245 |
+
# --- Enhanced Prediction Function with Dark Mode Interface ---
|
246 |
+
def enhanced_classify_and_explain(img):
|
247 |
+
if img is None:
|
248 |
+
return None, "No image provided", 0, ""
|
249 |
+
|
250 |
+
gradcam_img, confidence_dict = classify_and_explain(img)
|
251 |
+
|
252 |
+
# Get predicted class and confidence
|
253 |
+
pred_class = max(confidence_dict, key=confidence_dict.get)
|
254 |
+
confidence = confidence_dict[pred_class]
|
255 |
+
|
256 |
+
# Create confidence bars HTML
|
257 |
+
confidence_bars_html = create_confidence_bars(confidence_dict)
|
258 |
+
|
259 |
+
return gradcam_img, pred_class, confidence, confidence_bars_html
|
260 |
+
|
261 |
+
# --- Enhanced Gradio Interface ---
|
262 |
+
with gr.Blocks(css=css, title="Wound Classification") as demo:
|
263 |
+
gr.HTML("""
|
264 |
+
<div class="main-header">
|
265 |
+
<h1>Wound Classification</h1>
|
266 |
+
</div>
|
267 |
+
""")
|
268 |
+
|
269 |
+
with gr.Row():
|
270 |
+
with gr.Column(scale=1):
|
271 |
+
gr.HTML("<div class='section-title'>Input Image</div>")
|
272 |
+
input_image = gr.Image(
|
273 |
+
label="Upload wound image",
|
274 |
+
type="pil",
|
275 |
+
height=350,
|
276 |
+
container=True
|
277 |
+
)
|
278 |
+
|
279 |
+
with gr.Column(scale=1):
|
280 |
+
gr.HTML("<div class='section-title'>Analysis Results</div>")
|
281 |
+
|
282 |
+
# Prediction results
|
283 |
+
prediction_output = gr.Textbox(
|
284 |
+
label="Predicted Wound Type",
|
285 |
+
interactive=False,
|
286 |
+
container=True
|
287 |
+
)
|
288 |
+
|
289 |
+
confidence_output = gr.Number(
|
290 |
+
label="Confidence Score",
|
291 |
+
interactive=False,
|
292 |
+
container=True
|
293 |
+
)
|
294 |
+
|
295 |
+
# Confidence bars for all classes
|
296 |
+
confidence_bars = gr.HTML(
|
297 |
+
label="Confidence Scores by Class",
|
298 |
+
container=True
|
299 |
+
)
|
300 |
+
|
301 |
+
with gr.Row():
|
302 |
+
with gr.Column():
|
303 |
+
gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
|
304 |
+
cam_output = gr.Image(
|
305 |
+
label="Grad-CAM Heatmap - Shows which areas the model focused on",
|
306 |
+
height=350,
|
307 |
+
container=True
|
308 |
+
)
|
309 |
+
|
310 |
+
# Event handlers
|
311 |
+
input_image.change(
|
312 |
+
fn=enhanced_classify_and_explain,
|
313 |
+
inputs=[input_image],
|
314 |
+
outputs=[cam_output, prediction_output, confidence_output, confidence_bars]
|
315 |
+
)
|
316 |
+
|
317 |
+
# --- Launch the enhanced interface ---
|
318 |
+
if __name__ == "__main__":
|
319 |
+
demo.launch(
|
320 |
+
server_name="0.0.0.0",
|
321 |
+
server_port=7860,
|
322 |
+
share=True,
|
323 |
+
show_error=True
|
324 |
+
)
|
checkpoints/labels.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
0 Burns
|
2 |
+
1 Surgical Wound
|
3 |
+
2 Traumatic Wound
|
4 |
+
3 Diabetic Foot Ulcer
|
depth_anything_v2/__pycache__/dinov2.cpython-312.pyc
ADDED
Binary file (18.7 kB). View file
|
|
depth_anything_v2/__pycache__/dpt.cpython-312.pyc
ADDED
Binary file (10.6 kB). View file
|
|
depth_anything_v2/dinov2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
# we add a small number to avoid floating point error in the interpolation
|
192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
193 |
+
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
|
194 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
195 |
+
# w0, h0 = w0 + 0.1, h0 + 0.1
|
196 |
+
|
197 |
+
sqrt_N = math.sqrt(N)
|
198 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
199 |
+
patch_pos_embed = nn.functional.interpolate(
|
200 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
201 |
+
scale_factor=(sx, sy),
|
202 |
+
# (int(w0), int(h0)), # to solve the upsampling shape issue
|
203 |
+
mode="bicubic",
|
204 |
+
antialias=self.interpolate_antialias
|
205 |
+
)
|
206 |
+
|
207 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
208 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
209 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
210 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
211 |
+
|
212 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
213 |
+
B, nc, w, h = x.shape
|
214 |
+
x = self.patch_embed(x)
|
215 |
+
if masks is not None:
|
216 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
217 |
+
|
218 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
219 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
220 |
+
|
221 |
+
if self.register_tokens is not None:
|
222 |
+
x = torch.cat(
|
223 |
+
(
|
224 |
+
x[:, :1],
|
225 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
226 |
+
x[:, 1:],
|
227 |
+
),
|
228 |
+
dim=1,
|
229 |
+
)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward_features_list(self, x_list, masks_list):
|
234 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
235 |
+
for blk in self.blocks:
|
236 |
+
x = blk(x)
|
237 |
+
|
238 |
+
all_x = x
|
239 |
+
output = []
|
240 |
+
for x, masks in zip(all_x, masks_list):
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
output.append(
|
243 |
+
{
|
244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
247 |
+
"x_prenorm": x,
|
248 |
+
"masks": masks,
|
249 |
+
}
|
250 |
+
)
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
|
262 |
+
x_norm = self.norm(x)
|
263 |
+
return {
|
264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
265 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
266 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
267 |
+
"x_prenorm": x,
|
268 |
+
"masks": masks,
|
269 |
+
}
|
270 |
+
|
271 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
272 |
+
x = self.prepare_tokens_with_masks(x)
|
273 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
274 |
+
output, total_block_len = [], len(self.blocks)
|
275 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
276 |
+
for i, blk in enumerate(self.blocks):
|
277 |
+
x = blk(x)
|
278 |
+
if i in blocks_to_take:
|
279 |
+
output.append(x)
|
280 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
281 |
+
return output
|
282 |
+
|
283 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
284 |
+
x = self.prepare_tokens_with_masks(x)
|
285 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
288 |
+
for block_chunk in self.blocks:
|
289 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
290 |
+
x = blk(x)
|
291 |
+
if i in blocks_to_take:
|
292 |
+
output.append(x)
|
293 |
+
i += 1
|
294 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
295 |
+
return output
|
296 |
+
|
297 |
+
def get_intermediate_layers(
|
298 |
+
self,
|
299 |
+
x: torch.Tensor,
|
300 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
301 |
+
reshape: bool = False,
|
302 |
+
return_class_token: bool = False,
|
303 |
+
norm=True
|
304 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
305 |
+
if self.chunked_blocks:
|
306 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
307 |
+
else:
|
308 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
309 |
+
if norm:
|
310 |
+
outputs = [self.norm(out) for out in outputs]
|
311 |
+
class_tokens = [out[:, 0] for out in outputs]
|
312 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
313 |
+
if reshape:
|
314 |
+
B, _, w, h = x.shape
|
315 |
+
outputs = [
|
316 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
317 |
+
for out in outputs
|
318 |
+
]
|
319 |
+
if return_class_token:
|
320 |
+
return tuple(zip(outputs, class_tokens))
|
321 |
+
return tuple(outputs)
|
322 |
+
|
323 |
+
def forward(self, *args, is_training=False, **kwargs):
|
324 |
+
ret = self.forward_features(*args, **kwargs)
|
325 |
+
if is_training:
|
326 |
+
return ret
|
327 |
+
else:
|
328 |
+
return self.head(ret["x_norm_clstoken"])
|
329 |
+
|
330 |
+
|
331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
332 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
333 |
+
if isinstance(module, nn.Linear):
|
334 |
+
trunc_normal_(module.weight, std=0.02)
|
335 |
+
if module.bias is not None:
|
336 |
+
nn.init.zeros_(module.bias)
|
337 |
+
|
338 |
+
|
339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
340 |
+
model = DinoVisionTransformer(
|
341 |
+
patch_size=patch_size,
|
342 |
+
embed_dim=384,
|
343 |
+
depth=12,
|
344 |
+
num_heads=6,
|
345 |
+
mlp_ratio=4,
|
346 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
347 |
+
num_register_tokens=num_register_tokens,
|
348 |
+
**kwargs,
|
349 |
+
)
|
350 |
+
return model
|
351 |
+
|
352 |
+
|
353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
354 |
+
model = DinoVisionTransformer(
|
355 |
+
patch_size=patch_size,
|
356 |
+
embed_dim=768,
|
357 |
+
depth=12,
|
358 |
+
num_heads=12,
|
359 |
+
mlp_ratio=4,
|
360 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
361 |
+
num_register_tokens=num_register_tokens,
|
362 |
+
**kwargs,
|
363 |
+
)
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
368 |
+
model = DinoVisionTransformer(
|
369 |
+
patch_size=patch_size,
|
370 |
+
embed_dim=1024,
|
371 |
+
depth=24,
|
372 |
+
num_heads=16,
|
373 |
+
mlp_ratio=4,
|
374 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
375 |
+
num_register_tokens=num_register_tokens,
|
376 |
+
**kwargs,
|
377 |
+
)
|
378 |
+
return model
|
379 |
+
|
380 |
+
|
381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
382 |
+
"""
|
383 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
384 |
+
"""
|
385 |
+
model = DinoVisionTransformer(
|
386 |
+
patch_size=patch_size,
|
387 |
+
embed_dim=1536,
|
388 |
+
depth=40,
|
389 |
+
num_heads=24,
|
390 |
+
mlp_ratio=4,
|
391 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
392 |
+
num_register_tokens=num_register_tokens,
|
393 |
+
**kwargs,
|
394 |
+
)
|
395 |
+
return model
|
396 |
+
|
397 |
+
|
398 |
+
def DINOv2(model_name):
|
399 |
+
model_zoo = {
|
400 |
+
"vits": vit_small,
|
401 |
+
"vitb": vit_base,
|
402 |
+
"vitl": vit_large,
|
403 |
+
"vitg": vit_giant2
|
404 |
+
}
|
405 |
+
|
406 |
+
return model_zoo[model_name](
|
407 |
+
img_size=518,
|
408 |
+
patch_size=14,
|
409 |
+
init_values=1.0,
|
410 |
+
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
|
411 |
+
block_chunks=0,
|
412 |
+
num_register_tokens=0,
|
413 |
+
interpolate_antialias=False,
|
414 |
+
interpolate_offset=0.1
|
415 |
+
)
|
depth_anything_v2/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (441 Bytes). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-312.pyc
ADDED
Binary file (3.95 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/block.cpython-312.pyc
ADDED
Binary file (13.1 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-312.pyc
ADDED
Binary file (1.65 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-312.pyc
ADDED
Binary file (1.42 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-312.pyc
ADDED
Binary file (1.85 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-312.pyc
ADDED
Binary file (4.06 kB). View file
|
|
depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-312.pyc
ADDED
Binary file (2.84 kB). View file
|
|
depth_anything_v2/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
22 |
+
|
23 |
+
XFORMERS_AVAILABLE = True
|
24 |
+
except ImportError:
|
25 |
+
logger.warning("xFormers not available")
|
26 |
+
XFORMERS_AVAILABLE = False
|
27 |
+
|
28 |
+
|
29 |
+
class Attention(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim: int,
|
33 |
+
num_heads: int = 8,
|
34 |
+
qkv_bias: bool = False,
|
35 |
+
proj_bias: bool = True,
|
36 |
+
attn_drop: float = 0.0,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.num_heads = num_heads
|
41 |
+
head_dim = dim // num_heads
|
42 |
+
self.scale = head_dim**-0.5
|
43 |
+
|
44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
48 |
+
|
49 |
+
def forward(self, x: Tensor) -> Tensor:
|
50 |
+
B, N, C = x.shape
|
51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
52 |
+
|
53 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
54 |
+
attn = q @ k.transpose(-2, -1)
|
55 |
+
|
56 |
+
attn = attn.softmax(dim=-1)
|
57 |
+
attn = self.attn_drop(attn)
|
58 |
+
|
59 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
60 |
+
x = self.proj(x)
|
61 |
+
x = self.proj_drop(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class MemEffAttention(Attention):
|
66 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
67 |
+
if not XFORMERS_AVAILABLE:
|
68 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
69 |
+
return super().forward(x)
|
70 |
+
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
73 |
+
|
74 |
+
q, k, v = unbind(qkv, 2)
|
75 |
+
|
76 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
77 |
+
x = x.reshape([B, N, C])
|
78 |
+
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
depth_anything_v2/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int,
|
41 |
+
mlp_ratio: float = 4.0,
|
42 |
+
qkv_bias: bool = False,
|
43 |
+
proj_bias: bool = True,
|
44 |
+
ffn_bias: bool = True,
|
45 |
+
drop: float = 0.0,
|
46 |
+
attn_drop: float = 0.0,
|
47 |
+
init_values=None,
|
48 |
+
drop_path: float = 0.0,
|
49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
56 |
+
self.norm1 = norm_layer(dim)
|
57 |
+
self.attn = attn_class(
|
58 |
+
dim,
|
59 |
+
num_heads=num_heads,
|
60 |
+
qkv_bias=qkv_bias,
|
61 |
+
proj_bias=proj_bias,
|
62 |
+
attn_drop=attn_drop,
|
63 |
+
proj_drop=drop,
|
64 |
+
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
+
|
80 |
+
self.sample_drop_ratio = drop_path
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
85 |
+
|
86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
88 |
+
|
89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
91 |
+
x = drop_add_residual_stochastic_depth(
|
92 |
+
x,
|
93 |
+
residual_func=attn_residual_func,
|
94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
95 |
+
)
|
96 |
+
x = drop_add_residual_stochastic_depth(
|
97 |
+
x,
|
98 |
+
residual_func=ffn_residual_func,
|
99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
100 |
+
)
|
101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
103 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
104 |
+
else:
|
105 |
+
x = x + attn_residual_func(x)
|
106 |
+
x = x + ffn_residual_func(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
|
111 |
+
x: Tensor,
|
112 |
+
residual_func: Callable[[Tensor], Tensor],
|
113 |
+
sample_drop_ratio: float = 0.0,
|
114 |
+
) -> Tensor:
|
115 |
+
# 1) extract subset using permutation
|
116 |
+
b, n, d = x.shape
|
117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
119 |
+
x_subset = x[brange]
|
120 |
+
|
121 |
+
# 2) apply residual_func to get residual
|
122 |
+
residual = residual_func(x_subset)
|
123 |
+
|
124 |
+
x_flat = x.flatten(1)
|
125 |
+
residual = residual.flatten(1)
|
126 |
+
|
127 |
+
residual_scale_factor = b / sample_subset_size
|
128 |
+
|
129 |
+
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
131 |
+
return x_plus_residual.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
135 |
+
b, n, d = x.shape
|
136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
138 |
+
residual_scale_factor = b / sample_subset_size
|
139 |
+
return brange, residual_scale_factor
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
143 |
+
if scaling_vector is None:
|
144 |
+
x_flat = x.flatten(1)
|
145 |
+
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
147 |
+
else:
|
148 |
+
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
150 |
+
)
|
151 |
+
return x_plus_residual
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
158 |
+
"""
|
159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
+
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
+
if all_shapes not in attn_bias_cache.keys():
|
164 |
+
seqlens = []
|
165 |
+
for b, x in zip(batch_sizes, x_list):
|
166 |
+
for _ in range(b):
|
167 |
+
seqlens.append(x.shape[1])
|
168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
169 |
+
attn_bias._batch_sizes = batch_sizes
|
170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
171 |
+
|
172 |
+
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
174 |
+
else:
|
175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
177 |
+
|
178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
|
182 |
+
x_list: List[Tensor],
|
183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
184 |
+
sample_drop_ratio: float = 0.0,
|
185 |
+
scaling_vector=None,
|
186 |
+
) -> Tensor:
|
187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
189 |
+
branges = [s[0] for s in branges_scales]
|
190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
+
|
192 |
+
# 2) get attention bias and index+concat the tensors
|
193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
194 |
+
|
195 |
+
# 3) apply residual_func to get residual, and split the result
|
196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
+
|
198 |
+
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
|
205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
206 |
+
"""
|
207 |
+
x_list contains a list of tensors to nest together and run
|
208 |
+
"""
|
209 |
+
assert isinstance(self.attn, MemEffAttention)
|
210 |
+
|
211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
212 |
+
|
213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
215 |
+
|
216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
217 |
+
return self.mlp(self.norm2(x))
|
218 |
+
|
219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
220 |
+
x_list,
|
221 |
+
residual_func=attn_residual_func,
|
222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
224 |
+
)
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=ffn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
230 |
+
)
|
231 |
+
return x_list
|
232 |
+
else:
|
233 |
+
|
234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
236 |
+
|
237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
239 |
+
|
240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
242 |
+
x = x + ffn_residual_func(x)
|
243 |
+
return attn_bias.split(x)
|
244 |
+
|
245 |
+
def forward(self, x_or_x_list):
|
246 |
+
if isinstance(x_or_x_list, Tensor):
|
247 |
+
return super().forward(x_or_x_list)
|
248 |
+
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
250 |
+
return self.forward_nested(x_or_x_list)
|
251 |
+
else:
|
252 |
+
raise AssertionError
|
depth_anything_v2/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
10 |
+
|
11 |
+
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
16 |
+
if drop_prob == 0.0 or not training:
|
17 |
+
return x
|
18 |
+
keep_prob = 1 - drop_prob
|
19 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
20 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
21 |
+
if keep_prob > 0.0:
|
22 |
+
random_tensor.div_(keep_prob)
|
23 |
+
output = x * random_tensor
|
24 |
+
return output
|
25 |
+
|
26 |
+
|
27 |
+
class DropPath(nn.Module):
|
28 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
29 |
+
|
30 |
+
def __init__(self, drop_prob=None):
|
31 |
+
super(DropPath, self).__init__()
|
32 |
+
self.drop_prob = drop_prob
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return drop_path(x, self.drop_prob, self.training)
|
depth_anything_v2/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
8 |
+
|
9 |
+
from typing import Union
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from torch import Tensor
|
13 |
+
from torch import nn
|
14 |
+
|
15 |
+
|
16 |
+
class LayerScale(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim: int,
|
20 |
+
init_values: Union[float, Tensor] = 1e-5,
|
21 |
+
inplace: bool = False,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
self.inplace = inplace
|
25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
26 |
+
|
27 |
+
def forward(self, x: Tensor) -> Tensor:
|
28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
depth_anything_v2/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
10 |
+
|
11 |
+
|
12 |
+
from typing import Callable, Optional
|
13 |
+
|
14 |
+
from torch import Tensor, nn
|
15 |
+
|
16 |
+
|
17 |
+
class Mlp(nn.Module):
|
18 |
+
def __init__(
|
19 |
+
self,
|
20 |
+
in_features: int,
|
21 |
+
hidden_features: Optional[int] = None,
|
22 |
+
out_features: Optional[int] = None,
|
23 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
24 |
+
drop: float = 0.0,
|
25 |
+
bias: bool = True,
|
26 |
+
) -> None:
|
27 |
+
super().__init__()
|
28 |
+
out_features = out_features or in_features
|
29 |
+
hidden_features = hidden_features or in_features
|
30 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
31 |
+
self.act = act_layer()
|
32 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
33 |
+
self.drop = nn.Dropout(drop)
|
34 |
+
|
35 |
+
def forward(self, x: Tensor) -> Tensor:
|
36 |
+
x = self.fc1(x)
|
37 |
+
x = self.act(x)
|
38 |
+
x = self.drop(x)
|
39 |
+
x = self.fc2(x)
|
40 |
+
x = self.drop(x)
|
41 |
+
return x
|
depth_anything_v2/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
from typing import Callable, Optional, Tuple, Union
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
import torch.nn as nn
|
15 |
+
|
16 |
+
|
17 |
+
def make_2tuple(x):
|
18 |
+
if isinstance(x, tuple):
|
19 |
+
assert len(x) == 2
|
20 |
+
return x
|
21 |
+
|
22 |
+
assert isinstance(x, int)
|
23 |
+
return (x, x)
|
24 |
+
|
25 |
+
|
26 |
+
class PatchEmbed(nn.Module):
|
27 |
+
"""
|
28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
29 |
+
|
30 |
+
Args:
|
31 |
+
img_size: Image size.
|
32 |
+
patch_size: Patch token size.
|
33 |
+
in_chans: Number of input image channels.
|
34 |
+
embed_dim: Number of linear projection output channels.
|
35 |
+
norm_layer: Normalization layer.
|
36 |
+
"""
|
37 |
+
|
38 |
+
def __init__(
|
39 |
+
self,
|
40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
42 |
+
in_chans: int = 3,
|
43 |
+
embed_dim: int = 768,
|
44 |
+
norm_layer: Optional[Callable] = None,
|
45 |
+
flatten_embedding: bool = True,
|
46 |
+
) -> None:
|
47 |
+
super().__init__()
|
48 |
+
|
49 |
+
image_HW = make_2tuple(img_size)
|
50 |
+
patch_HW = make_2tuple(patch_size)
|
51 |
+
patch_grid_size = (
|
52 |
+
image_HW[0] // patch_HW[0],
|
53 |
+
image_HW[1] // patch_HW[1],
|
54 |
+
)
|
55 |
+
|
56 |
+
self.img_size = image_HW
|
57 |
+
self.patch_size = patch_HW
|
58 |
+
self.patches_resolution = patch_grid_size
|
59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
60 |
+
|
61 |
+
self.in_chans = in_chans
|
62 |
+
self.embed_dim = embed_dim
|
63 |
+
|
64 |
+
self.flatten_embedding = flatten_embedding
|
65 |
+
|
66 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
67 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
68 |
+
|
69 |
+
def forward(self, x: Tensor) -> Tensor:
|
70 |
+
_, _, H, W = x.shape
|
71 |
+
patch_H, patch_W = self.patch_size
|
72 |
+
|
73 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
74 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
75 |
+
|
76 |
+
x = self.proj(x) # B C H W
|
77 |
+
H, W = x.size(2), x.size(3)
|
78 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
79 |
+
x = self.norm(x)
|
80 |
+
if not self.flatten_embedding:
|
81 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
82 |
+
return x
|
83 |
+
|
84 |
+
def flops(self) -> float:
|
85 |
+
Ho, Wo = self.patches_resolution
|
86 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
87 |
+
if self.norm is not None:
|
88 |
+
flops += Ho * Wo * self.embed_dim
|
89 |
+
return flops
|
depth_anything_v2/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from typing import Callable, Optional
|
8 |
+
|
9 |
+
from torch import Tensor, nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
|
12 |
+
|
13 |
+
class SwiGLUFFN(nn.Module):
|
14 |
+
def __init__(
|
15 |
+
self,
|
16 |
+
in_features: int,
|
17 |
+
hidden_features: Optional[int] = None,
|
18 |
+
out_features: Optional[int] = None,
|
19 |
+
act_layer: Callable[..., nn.Module] = None,
|
20 |
+
drop: float = 0.0,
|
21 |
+
bias: bool = True,
|
22 |
+
) -> None:
|
23 |
+
super().__init__()
|
24 |
+
out_features = out_features or in_features
|
25 |
+
hidden_features = hidden_features or in_features
|
26 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
27 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
28 |
+
|
29 |
+
def forward(self, x: Tensor) -> Tensor:
|
30 |
+
x12 = self.w12(x)
|
31 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
32 |
+
hidden = F.silu(x1) * x2
|
33 |
+
return self.w3(hidden)
|
34 |
+
|
35 |
+
|
36 |
+
try:
|
37 |
+
from xformers.ops import SwiGLU
|
38 |
+
|
39 |
+
XFORMERS_AVAILABLE = True
|
40 |
+
except ImportError:
|
41 |
+
SwiGLU = SwiGLUFFN
|
42 |
+
XFORMERS_AVAILABLE = False
|
43 |
+
|
44 |
+
|
45 |
+
class SwiGLUFFNFused(SwiGLU):
|
46 |
+
def __init__(
|
47 |
+
self,
|
48 |
+
in_features: int,
|
49 |
+
hidden_features: Optional[int] = None,
|
50 |
+
out_features: Optional[int] = None,
|
51 |
+
act_layer: Callable[..., nn.Module] = None,
|
52 |
+
drop: float = 0.0,
|
53 |
+
bias: bool = True,
|
54 |
+
) -> None:
|
55 |
+
out_features = out_features or in_features
|
56 |
+
hidden_features = hidden_features or in_features
|
57 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
58 |
+
super().__init__(
|
59 |
+
in_features=in_features,
|
60 |
+
hidden_features=hidden_features,
|
61 |
+
out_features=out_features,
|
62 |
+
bias=bias,
|
63 |
+
)
|
depth_anything_v2/dpt.py
ADDED
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
from torchvision.transforms import Compose
|
6 |
+
|
7 |
+
from .dinov2 import DINOv2
|
8 |
+
from .util.blocks import FeatureFusionBlock, _make_scratch
|
9 |
+
from .util.transform import Resize, NormalizeImage, PrepareForNet
|
10 |
+
|
11 |
+
|
12 |
+
def _make_fusion_block(features, use_bn, size=None):
|
13 |
+
return FeatureFusionBlock(
|
14 |
+
features,
|
15 |
+
nn.ReLU(False),
|
16 |
+
deconv=False,
|
17 |
+
bn=use_bn,
|
18 |
+
expand=False,
|
19 |
+
align_corners=True,
|
20 |
+
size=size,
|
21 |
+
)
|
22 |
+
|
23 |
+
|
24 |
+
class ConvBlock(nn.Module):
|
25 |
+
def __init__(self, in_feature, out_feature):
|
26 |
+
super().__init__()
|
27 |
+
|
28 |
+
self.conv_block = nn.Sequential(
|
29 |
+
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
|
30 |
+
nn.BatchNorm2d(out_feature),
|
31 |
+
nn.ReLU(True)
|
32 |
+
)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return self.conv_block(x)
|
36 |
+
|
37 |
+
|
38 |
+
class DPTHead(nn.Module):
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
in_channels,
|
42 |
+
features=256,
|
43 |
+
use_bn=False,
|
44 |
+
out_channels=[256, 512, 1024, 1024],
|
45 |
+
use_clstoken=False
|
46 |
+
):
|
47 |
+
super(DPTHead, self).__init__()
|
48 |
+
|
49 |
+
self.use_clstoken = use_clstoken
|
50 |
+
|
51 |
+
self.projects = nn.ModuleList([
|
52 |
+
nn.Conv2d(
|
53 |
+
in_channels=in_channels,
|
54 |
+
out_channels=out_channel,
|
55 |
+
kernel_size=1,
|
56 |
+
stride=1,
|
57 |
+
padding=0,
|
58 |
+
) for out_channel in out_channels
|
59 |
+
])
|
60 |
+
|
61 |
+
self.resize_layers = nn.ModuleList([
|
62 |
+
nn.ConvTranspose2d(
|
63 |
+
in_channels=out_channels[0],
|
64 |
+
out_channels=out_channels[0],
|
65 |
+
kernel_size=4,
|
66 |
+
stride=4,
|
67 |
+
padding=0),
|
68 |
+
nn.ConvTranspose2d(
|
69 |
+
in_channels=out_channels[1],
|
70 |
+
out_channels=out_channels[1],
|
71 |
+
kernel_size=2,
|
72 |
+
stride=2,
|
73 |
+
padding=0),
|
74 |
+
nn.Identity(),
|
75 |
+
nn.Conv2d(
|
76 |
+
in_channels=out_channels[3],
|
77 |
+
out_channels=out_channels[3],
|
78 |
+
kernel_size=3,
|
79 |
+
stride=2,
|
80 |
+
padding=1)
|
81 |
+
])
|
82 |
+
|
83 |
+
if use_clstoken:
|
84 |
+
self.readout_projects = nn.ModuleList()
|
85 |
+
for _ in range(len(self.projects)):
|
86 |
+
self.readout_projects.append(
|
87 |
+
nn.Sequential(
|
88 |
+
nn.Linear(2 * in_channels, in_channels),
|
89 |
+
nn.GELU()))
|
90 |
+
|
91 |
+
self.scratch = _make_scratch(
|
92 |
+
out_channels,
|
93 |
+
features,
|
94 |
+
groups=1,
|
95 |
+
expand=False,
|
96 |
+
)
|
97 |
+
|
98 |
+
self.scratch.stem_transpose = None
|
99 |
+
|
100 |
+
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
|
101 |
+
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
|
102 |
+
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
|
103 |
+
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
|
104 |
+
|
105 |
+
head_features_1 = features
|
106 |
+
head_features_2 = 32
|
107 |
+
|
108 |
+
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
|
109 |
+
self.scratch.output_conv2 = nn.Sequential(
|
110 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
111 |
+
nn.ReLU(True),
|
112 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
113 |
+
nn.ReLU(True),
|
114 |
+
nn.Identity(),
|
115 |
+
)
|
116 |
+
|
117 |
+
def forward(self, out_features, patch_h, patch_w):
|
118 |
+
out = []
|
119 |
+
for i, x in enumerate(out_features):
|
120 |
+
if self.use_clstoken:
|
121 |
+
x, cls_token = x[0], x[1]
|
122 |
+
readout = cls_token.unsqueeze(1).expand_as(x)
|
123 |
+
x = self.readout_projects[i](torch.cat((x, readout), -1))
|
124 |
+
else:
|
125 |
+
x = x[0]
|
126 |
+
|
127 |
+
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
|
128 |
+
|
129 |
+
x = self.projects[i](x)
|
130 |
+
x = self.resize_layers[i](x)
|
131 |
+
|
132 |
+
out.append(x)
|
133 |
+
|
134 |
+
layer_1, layer_2, layer_3, layer_4 = out
|
135 |
+
|
136 |
+
layer_1_rn = self.scratch.layer1_rn(layer_1)
|
137 |
+
layer_2_rn = self.scratch.layer2_rn(layer_2)
|
138 |
+
layer_3_rn = self.scratch.layer3_rn(layer_3)
|
139 |
+
layer_4_rn = self.scratch.layer4_rn(layer_4)
|
140 |
+
|
141 |
+
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
|
142 |
+
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
|
143 |
+
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
|
144 |
+
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
|
145 |
+
|
146 |
+
out = self.scratch.output_conv1(path_1)
|
147 |
+
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
|
148 |
+
out = self.scratch.output_conv2(out)
|
149 |
+
|
150 |
+
return out
|
151 |
+
|
152 |
+
|
153 |
+
class DepthAnythingV2(nn.Module):
|
154 |
+
def __init__(
|
155 |
+
self,
|
156 |
+
encoder='vitl',
|
157 |
+
features=256,
|
158 |
+
out_channels=[256, 512, 1024, 1024],
|
159 |
+
use_bn=False,
|
160 |
+
use_clstoken=False
|
161 |
+
):
|
162 |
+
super(DepthAnythingV2, self).__init__()
|
163 |
+
|
164 |
+
self.intermediate_layer_idx = {
|
165 |
+
'vits': [2, 5, 8, 11],
|
166 |
+
'vitb': [2, 5, 8, 11],
|
167 |
+
'vitl': [4, 11, 17, 23],
|
168 |
+
'vitg': [9, 19, 29, 39]
|
169 |
+
}
|
170 |
+
|
171 |
+
self.encoder = encoder
|
172 |
+
self.pretrained = DINOv2(model_name=encoder)
|
173 |
+
|
174 |
+
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
|
178 |
+
|
179 |
+
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
|
180 |
+
|
181 |
+
depth = self.depth_head(features, patch_h, patch_w)
|
182 |
+
depth = F.relu(depth)
|
183 |
+
|
184 |
+
return depth.squeeze(1)
|
185 |
+
|
186 |
+
@torch.no_grad()
|
187 |
+
def infer_image(self, raw_image, input_size=518):
|
188 |
+
image, (h, w) = self.image2tensor(raw_image, input_size)
|
189 |
+
|
190 |
+
depth = self.forward(image)
|
191 |
+
|
192 |
+
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
|
193 |
+
|
194 |
+
return depth.cpu().numpy()
|
195 |
+
|
196 |
+
def image2tensor(self, raw_image, input_size=518):
|
197 |
+
transform = Compose([
|
198 |
+
Resize(
|
199 |
+
width=input_size,
|
200 |
+
height=input_size,
|
201 |
+
resize_target=False,
|
202 |
+
keep_aspect_ratio=True,
|
203 |
+
ensure_multiple_of=14,
|
204 |
+
resize_method='lower_bound',
|
205 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
206 |
+
),
|
207 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
208 |
+
PrepareForNet(),
|
209 |
+
])
|
210 |
+
|
211 |
+
h, w = raw_image.shape[:2]
|
212 |
+
|
213 |
+
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
|
214 |
+
|
215 |
+
image = transform({'image': image})['image']
|
216 |
+
image = torch.from_numpy(image).unsqueeze(0)
|
217 |
+
|
218 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
219 |
+
image = image.to(DEVICE)
|
220 |
+
|
221 |
+
return image, (h, w)
|
depth_anything_v2/util/__pycache__/blocks.cpython-312.pyc
ADDED
Binary file (5.55 kB). View file
|
|
depth_anything_v2/util/__pycache__/transform.cpython-312.pyc
ADDED
Binary file (7.45 kB). View file
|
|
depth_anything_v2/util/blocks.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
|
3 |
+
|
4 |
+
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
|
5 |
+
scratch = nn.Module()
|
6 |
+
|
7 |
+
out_shape1 = out_shape
|
8 |
+
out_shape2 = out_shape
|
9 |
+
out_shape3 = out_shape
|
10 |
+
if len(in_shape) >= 4:
|
11 |
+
out_shape4 = out_shape
|
12 |
+
|
13 |
+
if expand:
|
14 |
+
out_shape1 = out_shape
|
15 |
+
out_shape2 = out_shape * 2
|
16 |
+
out_shape3 = out_shape * 4
|
17 |
+
if len(in_shape) >= 4:
|
18 |
+
out_shape4 = out_shape * 8
|
19 |
+
|
20 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
21 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
22 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
23 |
+
if len(in_shape) >= 4:
|
24 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
|
25 |
+
|
26 |
+
return scratch
|
27 |
+
|
28 |
+
|
29 |
+
class ResidualConvUnit(nn.Module):
|
30 |
+
"""Residual convolution module.
|
31 |
+
"""
|
32 |
+
|
33 |
+
def __init__(self, features, activation, bn):
|
34 |
+
"""Init.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
features (int): number of features
|
38 |
+
"""
|
39 |
+
super().__init__()
|
40 |
+
|
41 |
+
self.bn = bn
|
42 |
+
|
43 |
+
self.groups=1
|
44 |
+
|
45 |
+
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
46 |
+
|
47 |
+
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
|
48 |
+
|
49 |
+
if self.bn == True:
|
50 |
+
self.bn1 = nn.BatchNorm2d(features)
|
51 |
+
self.bn2 = nn.BatchNorm2d(features)
|
52 |
+
|
53 |
+
self.activation = activation
|
54 |
+
|
55 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
56 |
+
|
57 |
+
def forward(self, x):
|
58 |
+
"""Forward pass.
|
59 |
+
|
60 |
+
Args:
|
61 |
+
x (tensor): input
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
tensor: output
|
65 |
+
"""
|
66 |
+
|
67 |
+
out = self.activation(x)
|
68 |
+
out = self.conv1(out)
|
69 |
+
if self.bn == True:
|
70 |
+
out = self.bn1(out)
|
71 |
+
|
72 |
+
out = self.activation(out)
|
73 |
+
out = self.conv2(out)
|
74 |
+
if self.bn == True:
|
75 |
+
out = self.bn2(out)
|
76 |
+
|
77 |
+
if self.groups > 1:
|
78 |
+
out = self.conv_merge(out)
|
79 |
+
|
80 |
+
return self.skip_add.add(out, x)
|
81 |
+
|
82 |
+
|
83 |
+
class FeatureFusionBlock(nn.Module):
|
84 |
+
"""Feature fusion block.
|
85 |
+
"""
|
86 |
+
|
87 |
+
def __init__(
|
88 |
+
self,
|
89 |
+
features,
|
90 |
+
activation,
|
91 |
+
deconv=False,
|
92 |
+
bn=False,
|
93 |
+
expand=False,
|
94 |
+
align_corners=True,
|
95 |
+
size=None
|
96 |
+
):
|
97 |
+
"""Init.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
features (int): number of features
|
101 |
+
"""
|
102 |
+
super(FeatureFusionBlock, self).__init__()
|
103 |
+
|
104 |
+
self.deconv = deconv
|
105 |
+
self.align_corners = align_corners
|
106 |
+
|
107 |
+
self.groups=1
|
108 |
+
|
109 |
+
self.expand = expand
|
110 |
+
out_features = features
|
111 |
+
if self.expand == True:
|
112 |
+
out_features = features // 2
|
113 |
+
|
114 |
+
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
|
115 |
+
|
116 |
+
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
|
117 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
|
118 |
+
|
119 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
120 |
+
|
121 |
+
self.size=size
|
122 |
+
|
123 |
+
def forward(self, *xs, size=None):
|
124 |
+
"""Forward pass.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
tensor: output
|
128 |
+
"""
|
129 |
+
output = xs[0]
|
130 |
+
|
131 |
+
if len(xs) == 2:
|
132 |
+
res = self.resConfUnit1(xs[1])
|
133 |
+
output = self.skip_add.add(output, res)
|
134 |
+
|
135 |
+
output = self.resConfUnit2(output)
|
136 |
+
|
137 |
+
if (size is None) and (self.size is None):
|
138 |
+
modifier = {"scale_factor": 2}
|
139 |
+
elif size is None:
|
140 |
+
modifier = {"size": self.size}
|
141 |
+
else:
|
142 |
+
modifier = {"size": size}
|
143 |
+
|
144 |
+
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
|
145 |
+
|
146 |
+
output = self.out_conv(output)
|
147 |
+
|
148 |
+
return output
|
depth_anything_v2/util/transform.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
|
5 |
+
class Resize(object):
|
6 |
+
"""Resize sample to given size (width, height).
|
7 |
+
"""
|
8 |
+
|
9 |
+
def __init__(
|
10 |
+
self,
|
11 |
+
width,
|
12 |
+
height,
|
13 |
+
resize_target=True,
|
14 |
+
keep_aspect_ratio=False,
|
15 |
+
ensure_multiple_of=1,
|
16 |
+
resize_method="lower_bound",
|
17 |
+
image_interpolation_method=cv2.INTER_AREA,
|
18 |
+
):
|
19 |
+
"""Init.
|
20 |
+
|
21 |
+
Args:
|
22 |
+
width (int): desired output width
|
23 |
+
height (int): desired output height
|
24 |
+
resize_target (bool, optional):
|
25 |
+
True: Resize the full sample (image, mask, target).
|
26 |
+
False: Resize image only.
|
27 |
+
Defaults to True.
|
28 |
+
keep_aspect_ratio (bool, optional):
|
29 |
+
True: Keep the aspect ratio of the input sample.
|
30 |
+
Output sample might not have the given width and height, and
|
31 |
+
resize behaviour depends on the parameter 'resize_method'.
|
32 |
+
Defaults to False.
|
33 |
+
ensure_multiple_of (int, optional):
|
34 |
+
Output width and height is constrained to be multiple of this parameter.
|
35 |
+
Defaults to 1.
|
36 |
+
resize_method (str, optional):
|
37 |
+
"lower_bound": Output will be at least as large as the given size.
|
38 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
39 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
40 |
+
Defaults to "lower_bound".
|
41 |
+
"""
|
42 |
+
self.__width = width
|
43 |
+
self.__height = height
|
44 |
+
|
45 |
+
self.__resize_target = resize_target
|
46 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
47 |
+
self.__multiple_of = ensure_multiple_of
|
48 |
+
self.__resize_method = resize_method
|
49 |
+
self.__image_interpolation_method = image_interpolation_method
|
50 |
+
|
51 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
52 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
53 |
+
|
54 |
+
if max_val is not None and y > max_val:
|
55 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
56 |
+
|
57 |
+
if y < min_val:
|
58 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
59 |
+
|
60 |
+
return y
|
61 |
+
|
62 |
+
def get_size(self, width, height):
|
63 |
+
# determine new height and width
|
64 |
+
scale_height = self.__height / height
|
65 |
+
scale_width = self.__width / width
|
66 |
+
|
67 |
+
if self.__keep_aspect_ratio:
|
68 |
+
if self.__resize_method == "lower_bound":
|
69 |
+
# scale such that output size is lower bound
|
70 |
+
if scale_width > scale_height:
|
71 |
+
# fit width
|
72 |
+
scale_height = scale_width
|
73 |
+
else:
|
74 |
+
# fit height
|
75 |
+
scale_width = scale_height
|
76 |
+
elif self.__resize_method == "upper_bound":
|
77 |
+
# scale such that output size is upper bound
|
78 |
+
if scale_width < scale_height:
|
79 |
+
# fit width
|
80 |
+
scale_height = scale_width
|
81 |
+
else:
|
82 |
+
# fit height
|
83 |
+
scale_width = scale_height
|
84 |
+
elif self.__resize_method == "minimal":
|
85 |
+
# scale as least as possbile
|
86 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
87 |
+
# fit width
|
88 |
+
scale_height = scale_width
|
89 |
+
else:
|
90 |
+
# fit height
|
91 |
+
scale_width = scale_height
|
92 |
+
else:
|
93 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
94 |
+
|
95 |
+
if self.__resize_method == "lower_bound":
|
96 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
|
97 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
|
98 |
+
elif self.__resize_method == "upper_bound":
|
99 |
+
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
|
100 |
+
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
|
101 |
+
elif self.__resize_method == "minimal":
|
102 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
103 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
104 |
+
else:
|
105 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
106 |
+
|
107 |
+
return (new_width, new_height)
|
108 |
+
|
109 |
+
def __call__(self, sample):
|
110 |
+
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
|
111 |
+
|
112 |
+
# resize sample
|
113 |
+
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
|
114 |
+
|
115 |
+
if self.__resize_target:
|
116 |
+
if "depth" in sample:
|
117 |
+
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
|
118 |
+
|
119 |
+
if "mask" in sample:
|
120 |
+
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
|
121 |
+
|
122 |
+
return sample
|
123 |
+
|
124 |
+
|
125 |
+
class NormalizeImage(object):
|
126 |
+
"""Normlize image by given mean and std.
|
127 |
+
"""
|
128 |
+
|
129 |
+
def __init__(self, mean, std):
|
130 |
+
self.__mean = mean
|
131 |
+
self.__std = std
|
132 |
+
|
133 |
+
def __call__(self, sample):
|
134 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
135 |
+
|
136 |
+
return sample
|
137 |
+
|
138 |
+
|
139 |
+
class PrepareForNet(object):
|
140 |
+
"""Prepare sample for usage as network input.
|
141 |
+
"""
|
142 |
+
|
143 |
+
def __init__(self):
|
144 |
+
pass
|
145 |
+
|
146 |
+
def __call__(self, sample):
|
147 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
148 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
149 |
+
|
150 |
+
if "depth" in sample:
|
151 |
+
depth = sample["depth"].astype(np.float32)
|
152 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
153 |
+
|
154 |
+
if "mask" in sample:
|
155 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
156 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
157 |
+
|
158 |
+
return sample
|
environment.yml
ADDED
Binary file (6.93 kB). View file
|
|
environment_export.yml
ADDED
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: depth_copy
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- defaults
|
5 |
+
dependencies:
|
6 |
+
- _libgcc_mutex=0.1=conda_forge
|
7 |
+
- _openmp_mutex=4.5=2_gnu
|
8 |
+
- bzip2=1.0.8=h4bc722e_7
|
9 |
+
- ca-certificates=2025.6.15=hbd8a1cb_0
|
10 |
+
- ld_impl_linux-64=2.44=h1423503_0
|
11 |
+
- libexpat=2.7.0=h5888daf_0
|
12 |
+
- libffi=3.4.6=h2dba641_1
|
13 |
+
- libgcc=15.1.0=h767d61c_3
|
14 |
+
- libgcc-ng=15.1.0=h69a702a_3
|
15 |
+
- libgomp=15.1.0=h767d61c_3
|
16 |
+
- liblzma=5.8.1=hb9d3cd8_2
|
17 |
+
- libnsl=2.0.1=hb9d3cd8_1
|
18 |
+
- libsqlite=3.50.2=h6cd9bfd_0
|
19 |
+
- libuuid=2.38.1=h0b41bf4_0
|
20 |
+
- libxcrypt=4.4.36=hd590300_1
|
21 |
+
- libzlib=1.3.1=hb9d3cd8_2
|
22 |
+
- ncurses=6.5=h2d0b736_3
|
23 |
+
- openssl=3.5.1=h7b32b05_0
|
24 |
+
- pip=25.1.1=pyh8b19718_0
|
25 |
+
- python=3.12.11=h9e4cc4f_0_cpython
|
26 |
+
- readline=8.2=h8c095d6_2
|
27 |
+
- setuptools=80.9.0=pyhff2d567_0
|
28 |
+
- tk=8.6.13=noxft_hd72426e_102
|
29 |
+
- wheel=0.45.1=pyhd8ed1ab_1
|
30 |
+
- pip:
|
31 |
+
- absl-py==2.3.1
|
32 |
+
- addict==2.4.0
|
33 |
+
- aiofiles==24.1.0
|
34 |
+
- annotated-types==0.7.0
|
35 |
+
- anyio==4.9.0
|
36 |
+
- asttokens==3.0.0
|
37 |
+
- astunparse==1.6.3
|
38 |
+
- attrs==25.3.0
|
39 |
+
- blinker==1.9.0
|
40 |
+
- certifi==2025.6.15
|
41 |
+
- charset-normalizer==3.4.2
|
42 |
+
- click==8.2.1
|
43 |
+
- colorama==0.4.6
|
44 |
+
- comm==0.2.2
|
45 |
+
- configargparse==1.7.1
|
46 |
+
- contourpy==1.3.2
|
47 |
+
- cycler==0.12.1
|
48 |
+
- dash==3.1.1
|
49 |
+
- decorator==5.2.1
|
50 |
+
- executing==2.2.0
|
51 |
+
- fastapi==0.115.14
|
52 |
+
- fastjsonschema==2.21.1
|
53 |
+
- ffmpy==0.6.0
|
54 |
+
- filelock==3.18.0
|
55 |
+
- flask==3.1.1
|
56 |
+
- flatbuffers==25.2.10
|
57 |
+
- fonttools==4.58.4
|
58 |
+
- fsspec==2025.5.1
|
59 |
+
- gast==0.6.0
|
60 |
+
- google-pasta==0.2.0
|
61 |
+
- gradio==5.35.0
|
62 |
+
- gradio-client==1.10.4
|
63 |
+
- gradio-imageslider==0.0.20
|
64 |
+
- groovy==0.1.2
|
65 |
+
- grpcio==1.73.1
|
66 |
+
- h11==0.16.0
|
67 |
+
- h5py==3.14.0
|
68 |
+
- hf-xet==1.1.5
|
69 |
+
- httpcore==1.0.9
|
70 |
+
- httpx==0.28.1
|
71 |
+
- huggingface-hub==0.33.2
|
72 |
+
- idna==3.10
|
73 |
+
- importlib-metadata==8.7.0
|
74 |
+
- ipython==9.4.0
|
75 |
+
- ipython-pygments-lexers==1.1.1
|
76 |
+
- ipywidgets==8.1.7
|
77 |
+
- itsdangerous==2.2.0
|
78 |
+
- jedi==0.19.2
|
79 |
+
- jinja2==3.1.6
|
80 |
+
- joblib==1.5.1
|
81 |
+
- jsonschema==4.24.0
|
82 |
+
- jsonschema-specifications==2025.4.1
|
83 |
+
- jupyter-core==5.8.1
|
84 |
+
- jupyterlab-widgets==3.0.15
|
85 |
+
- keras==3.10.0
|
86 |
+
- kiwisolver==1.4.8
|
87 |
+
- libclang==18.1.1
|
88 |
+
- markdown==3.8.2
|
89 |
+
- markdown-it-py==3.0.0
|
90 |
+
- markupsafe==3.0.2
|
91 |
+
- matplotlib==3.10.3
|
92 |
+
- matplotlib-inline==0.1.7
|
93 |
+
- mdurl==0.1.2
|
94 |
+
- ml-dtypes==0.5.1
|
95 |
+
- mpmath==1.3.0
|
96 |
+
- namex==0.1.0
|
97 |
+
- narwhals==1.45.0
|
98 |
+
- nbformat==5.10.4
|
99 |
+
- nest-asyncio==1.6.0
|
100 |
+
- networkx==3.5
|
101 |
+
- numpy==2.1.3
|
102 |
+
- nvidia-cublas-cu12==12.6.4.1
|
103 |
+
- nvidia-cuda-cupti-cu12==12.6.80
|
104 |
+
- nvidia-cuda-nvrtc-cu12==12.6.77
|
105 |
+
- nvidia-cuda-runtime-cu12==12.6.77
|
106 |
+
- nvidia-cudnn-cu12==9.5.1.17
|
107 |
+
- nvidia-cufft-cu12==11.3.0.4
|
108 |
+
- nvidia-cufile-cu12==1.11.1.6
|
109 |
+
- nvidia-curand-cu12==10.3.7.77
|
110 |
+
- nvidia-cusolver-cu12==11.7.1.2
|
111 |
+
- nvidia-cusparse-cu12==12.5.4.2
|
112 |
+
- nvidia-cusparselt-cu12==0.6.3
|
113 |
+
- nvidia-nccl-cu12==2.26.2
|
114 |
+
- nvidia-nvjitlink-cu12==12.6.85
|
115 |
+
- nvidia-nvtx-cu12==12.6.77
|
116 |
+
- open3d==0.19.0
|
117 |
+
- opencv-python==4.11.0.86
|
118 |
+
- opt-einsum==3.4.0
|
119 |
+
- optree==0.16.0
|
120 |
+
- orjson==3.10.18
|
121 |
+
- packaging==25.0
|
122 |
+
- pandas==2.3.0
|
123 |
+
- parso==0.8.4
|
124 |
+
- pexpect==4.9.0
|
125 |
+
- pillow==11.3.0
|
126 |
+
- platformdirs==4.3.8
|
127 |
+
- plotly==6.2.0
|
128 |
+
- prompt-toolkit==3.0.51
|
129 |
+
- protobuf==5.29.5
|
130 |
+
- ptyprocess==0.7.0
|
131 |
+
- pure-eval==0.2.3
|
132 |
+
- pydantic==2.11.7
|
133 |
+
- pydantic-core==2.33.2
|
134 |
+
- pydub==0.25.1
|
135 |
+
- pygments==2.19.2
|
136 |
+
- pyparsing==3.2.3
|
137 |
+
- pyquaternion==0.9.9
|
138 |
+
- python-dateutil==2.9.0.post0
|
139 |
+
- python-multipart==0.0.20
|
140 |
+
- pytz==2025.2
|
141 |
+
- pyyaml==6.0.2
|
142 |
+
- referencing==0.36.2
|
143 |
+
- requests==2.32.4
|
144 |
+
- retrying==1.4.0
|
145 |
+
- rich==14.0.0
|
146 |
+
- rpds-py==0.26.0
|
147 |
+
- ruff==0.12.1
|
148 |
+
- safehttpx==0.1.6
|
149 |
+
- scikit-learn==1.7.0
|
150 |
+
- scipy==1.16.0
|
151 |
+
- semantic-version==2.10.0
|
152 |
+
- shellingham==1.5.4
|
153 |
+
- six==1.17.0
|
154 |
+
- sniffio==1.3.1
|
155 |
+
- stack-data==0.6.3
|
156 |
+
- starlette==0.46.2
|
157 |
+
- sympy==1.14.0
|
158 |
+
- tensorboard==2.19.0
|
159 |
+
- tensorboard-data-server==0.7.2
|
160 |
+
- tensorflow==2.19.0
|
161 |
+
- termcolor==3.1.0
|
162 |
+
- threadpoolctl==3.6.0
|
163 |
+
- tomlkit==0.13.3
|
164 |
+
- torch==2.7.1
|
165 |
+
- torchaudio==2.7.1
|
166 |
+
- torchvision==0.22.1
|
167 |
+
- tqdm==4.67.1
|
168 |
+
- traitlets==5.14.3
|
169 |
+
- triton==3.3.1
|
170 |
+
- typer==0.16.0
|
171 |
+
- typing-extensions==4.14.0
|
172 |
+
- typing-inspection==0.4.1
|
173 |
+
- tzdata==2025.2
|
174 |
+
- urllib3==2.5.0
|
175 |
+
- uvicorn==0.35.0
|
176 |
+
- wcwidth==0.2.13
|
177 |
+
- websockets==15.0.1
|
178 |
+
- werkzeug==3.1.3
|
179 |
+
- widgetsnbextension==4.0.14
|
180 |
+
- wrapt==1.17.2
|
181 |
+
- zipp==3.23.0
|
182 |
+
prefix: /home/uphen/anaconda3/envs/depth_copy
|
environment_from_history.yml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: depth_copy
|
2 |
+
channels:
|
3 |
+
- defaults
|
4 |
+
dependencies:
|
5 |
+
- conda-forge/linux-64::_libgcc_mutex==0.1=conda_forge
|
6 |
+
- conda-forge/noarch::ca-certificates==2025.6.15=hbd8a1cb_0
|
7 |
+
- conda-forge/linux-64::ld_impl_linux-64==2.44=h1423503_0
|
8 |
+
- conda-forge/linux-64::libgomp==15.1.0=h767d61c_3
|
9 |
+
- conda-forge/noarch::tzdata==2025b=h78e105d_0
|
10 |
+
- conda-forge/linux-64::_openmp_mutex==4.5=2_gnu
|
11 |
+
- conda-forge/linux-64::libgcc==15.1.0=h767d61c_3
|
12 |
+
- conda-forge/linux-64::libexpat==2.7.0=h5888daf_0
|
13 |
+
- conda-forge/linux-64::libffi==3.4.6=h2dba641_1
|
14 |
+
- conda-forge/linux-64::libgcc-ng==15.1.0=h69a702a_3
|
15 |
+
- conda-forge/linux-64::liblzma==5.8.1=hb9d3cd8_2
|
16 |
+
- conda-forge/linux-64::libnsl==2.0.1=hb9d3cd8_1
|
17 |
+
- conda-forge/linux-64::libzlib==1.3.1=hb9d3cd8_2
|
18 |
+
- conda-forge/linux-64::ncurses==6.5=h2d0b736_3
|
19 |
+
- conda-forge/linux-64::openssl==3.5.1=h7b32b05_0
|
20 |
+
- conda-forge/linux-64::bzip2==1.0.8=h4bc722e_7
|
21 |
+
- conda-forge/linux-64::libsqlite==3.50.2=h6cd9bfd_0
|
22 |
+
- conda-forge/linux-64::libuuid==2.38.1=h0b41bf4_0
|
23 |
+
- conda-forge/linux-64::libxcrypt==4.4.36=hd590300_1
|
24 |
+
- conda-forge/linux-64::readline==8.2=h8c095d6_2
|
25 |
+
- conda-forge/linux-64::tk==8.6.13=noxft_hd72426e_102
|
26 |
+
- conda-forge/linux-64::python==3.12.11=h9e4cc4f_0_cpython
|
27 |
+
- conda-forge/noarch::setuptools==80.9.0=pyhff2d567_0
|
28 |
+
- conda-forge/noarch::wheel==0.45.1=pyhd8ed1ab_1
|
29 |
+
- conda-forge/noarch::pip==25.1.1=pyh8b19718_0
|
30 |
+
prefix: /home/uphen/anaconda3/envs/depth_copy
|
environment_linux.yml
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: depth
|
2 |
+
channels:
|
3 |
+
- conda-forge
|
4 |
+
- pytorch
|
5 |
+
- defaults
|
6 |
+
dependencies:
|
7 |
+
- python=3.12
|
8 |
+
- pip
|
9 |
+
- pip:
|
10 |
+
- aiofiles==24.1.0
|
11 |
+
- annotated-types==0.7.0
|
12 |
+
- anyio==4.9.0
|
13 |
+
- asttokens==3.0.0
|
14 |
+
- attrs==25.3.0
|
15 |
+
- blinker==1.9.0
|
16 |
+
- certifi==2025.6.15
|
17 |
+
- charset-normalizer==3.4.2
|
18 |
+
- click==8.2.1
|
19 |
+
- colorama==0.4.6
|
20 |
+
- comm==0.2.2
|
21 |
+
- configargparse==1.7.1
|
22 |
+
- contourpy==1.3.2
|
23 |
+
- cycler==0.12.1
|
24 |
+
- dash==3.1.1
|
25 |
+
- decorator==5.2.1
|
26 |
+
- executing==2.2.0
|
27 |
+
- fastapi==0.115.14
|
28 |
+
- fastjsonschema==2.21.1
|
29 |
+
- ffmpy==0.6.0
|
30 |
+
- filelock==3.18.0
|
31 |
+
- flask==3.1.1
|
32 |
+
- fonttools==4.58.4
|
33 |
+
- fsspec==2025.5.1
|
34 |
+
- gradio==5.35.0
|
35 |
+
- gradio-client==1.10.4
|
36 |
+
- gradio-imageslider==0.0.20
|
37 |
+
- groovy==0.1.2
|
38 |
+
- h11==0.16.0
|
39 |
+
- httpcore==1.0.9
|
40 |
+
- httpx==0.28.1
|
41 |
+
- huggingface-hub==0.33.2
|
42 |
+
- idna==3.10
|
43 |
+
- importlib-metadata==8.7.0
|
44 |
+
- ipython==9.4.0
|
45 |
+
- ipython-pygments-lexers==1.1.1
|
46 |
+
- ipywidgets==8.1.7
|
47 |
+
- itsdangerous==2.2.0
|
48 |
+
- jedi==0.19.2
|
49 |
+
- jinja2==3.1.6
|
50 |
+
- jsonschema==4.24.0
|
51 |
+
- jsonschema-specifications==2025.4.1
|
52 |
+
- jupyter-core==5.8.1
|
53 |
+
- jupyterlab-widgets==3.0.15
|
54 |
+
- kiwisolver==1.4.8
|
55 |
+
- markdown-it-py==3.0.0
|
56 |
+
- markupsafe==3.0.2
|
57 |
+
- matplotlib==3.10.3
|
58 |
+
- matplotlib-inline==0.1.7
|
59 |
+
- mdurl==0.1.2
|
60 |
+
- mpmath==1.3.0
|
61 |
+
- narwhals==1.45.0
|
62 |
+
- nbformat==5.10.4
|
63 |
+
- nest-asyncio==1.6.0
|
64 |
+
- networkx==3.5
|
65 |
+
- numpy==2.3.1
|
66 |
+
- open3d==0.19.0
|
67 |
+
- opencv-python==4.11.0.86
|
68 |
+
- orjson==3.10.18
|
69 |
+
- packaging==25.0
|
70 |
+
- pandas==2.3.0
|
71 |
+
- parso==0.8.4
|
72 |
+
- pillow==11.3.0
|
73 |
+
- platformdirs==4.3.8
|
74 |
+
- plotly==6.2.0
|
75 |
+
- prompt-toolkit==3.0.51
|
76 |
+
- pure-eval==0.2.3
|
77 |
+
- pydantic==2.11.7
|
78 |
+
- pydantic-core==2.33.2
|
79 |
+
- pydub==0.25.1
|
80 |
+
- pygments==2.19.2
|
81 |
+
- pyparsing==3.2.3
|
82 |
+
- python-dateutil==2.9.0.post0
|
83 |
+
- python-multipart==0.0.20
|
84 |
+
- pytz==2025.2
|
85 |
+
- pyyaml==6.0.2
|
86 |
+
- referencing==0.36.2
|
87 |
+
- requests==2.32.4
|
88 |
+
- retrying==1.4.0
|
89 |
+
- rich==14.0.0
|
90 |
+
- rpds-py==0.26.0
|
91 |
+
- ruff==0.12.1
|
92 |
+
- safehttpx==0.1.6
|
93 |
+
- semantic-version==2.10.0
|
94 |
+
- shellingham==1.5.4
|
95 |
+
- six==1.17.0
|
96 |
+
- sniffio==1.3.1
|
97 |
+
- stack-data==0.6.3
|
98 |
+
- starlette==0.46.2
|
99 |
+
- sympy==1.14.0
|
100 |
+
- tomlkit==0.13.3
|
101 |
+
- torch==2.7.1
|
102 |
+
- torchaudio==2.7.1
|
103 |
+
- torchvision==0.22.1
|
104 |
+
- tqdm==4.67.1
|
105 |
+
- traitlets==5.14.3
|
106 |
+
- typer==0.16.0
|
107 |
+
- typing-extensions==4.14.0
|
108 |
+
- typing-inspection==0.4.1
|
109 |
+
- tzdata==2025.2
|
110 |
+
- urllib3==2.5.0
|
111 |
+
- uvicorn==0.35.0
|
112 |
+
- wcwidth==0.2.13
|
113 |
+
- websockets==15.0.1
|
114 |
+
- werkzeug==3.1.3
|
115 |
+
- widgetsnbextension==4.0.14
|
116 |
+
- zipp==3.23.0
|
keras_model_3.h5
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4669439786977b2098b89ad88387538d8085bd129e1e6b9de7f7821aa8baa3fa
|
3 |
+
size 2453440
|
labels.txt
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
0 Burns
|
2 |
+
1 Surgical Wou...
|
3 |
+
2 Traumatic Wo...
|
4 |
+
3 Diabetic Foo...
|
main_app.py
ADDED
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
+
import tensorflow as tf
|
4 |
+
from tensorflow.keras.models import load_model
|
5 |
+
from tensorflow.keras.preprocessing import image as keras_image
|
6 |
+
from tensorflow.keras import backend as K
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
from PIL import Image
|
9 |
+
import io
|
10 |
+
import cv2
|
11 |
+
import glob
|
12 |
+
import matplotlib
|
13 |
+
import torch
|
14 |
+
import tempfile
|
15 |
+
from gradio_imageslider import ImageSlider
|
16 |
+
import plotly.graph_objects as go
|
17 |
+
import plotly.express as px
|
18 |
+
import open3d as o3d
|
19 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
20 |
+
|
21 |
+
# --- Load models ---
|
22 |
+
# Wound classification model
|
23 |
+
try:
|
24 |
+
wound_model = load_model("checkpoints/keras_model.h5")
|
25 |
+
with open("labels.txt", "r") as f:
|
26 |
+
class_labels = [line.strip() for line in f]
|
27 |
+
except:
|
28 |
+
wound_model = None
|
29 |
+
class_labels = ["No model found"]
|
30 |
+
|
31 |
+
# Depth estimation model
|
32 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
33 |
+
model_configs = {
|
34 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
35 |
+
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
|
36 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
37 |
+
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
|
38 |
+
}
|
39 |
+
encoder = 'vitl'
|
40 |
+
try:
|
41 |
+
depth_model = DepthAnythingV2(**model_configs[encoder])
|
42 |
+
state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
|
43 |
+
depth_model.load_state_dict(state_dict)
|
44 |
+
depth_model = depth_model.to(DEVICE).eval()
|
45 |
+
except:
|
46 |
+
depth_model = None
|
47 |
+
|
48 |
+
# --- Wound Classification Functions ---
|
49 |
+
def preprocess_input(img):
|
50 |
+
img = img.resize((224, 224))
|
51 |
+
arr = keras_image.img_to_array(img)
|
52 |
+
arr = arr / 255.0
|
53 |
+
return np.expand_dims(arr, axis=0)
|
54 |
+
|
55 |
+
def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
|
56 |
+
try:
|
57 |
+
target_layer = model.get_layer(last_conv_layer_name)
|
58 |
+
except:
|
59 |
+
for layer in model.layers:
|
60 |
+
if 'conv' in layer.name.lower():
|
61 |
+
target_layer = layer
|
62 |
+
break
|
63 |
+
else:
|
64 |
+
return None
|
65 |
+
|
66 |
+
grad_model = tf.keras.models.Model(
|
67 |
+
[model.inputs], [target_layer.output, model.output]
|
68 |
+
)
|
69 |
+
|
70 |
+
with tf.GradientTape() as tape:
|
71 |
+
conv_outputs, predictions = grad_model(img_array)
|
72 |
+
loss = predictions[:, class_index]
|
73 |
+
|
74 |
+
grads = tape.gradient(loss, conv_outputs)
|
75 |
+
if grads is None:
|
76 |
+
return None
|
77 |
+
|
78 |
+
grads = grads[0]
|
79 |
+
pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
|
80 |
+
conv_outputs = conv_outputs[0]
|
81 |
+
|
82 |
+
heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
|
83 |
+
heatmap = np.maximum(heatmap, 0)
|
84 |
+
heatmap = heatmap / np.max(heatmap + K.epsilon())
|
85 |
+
return heatmap.numpy()
|
86 |
+
|
87 |
+
def overlay_gradcam(original_img, heatmap):
|
88 |
+
if heatmap is None:
|
89 |
+
return original_img
|
90 |
+
|
91 |
+
heatmap = cv2.resize(heatmap, original_img.size)
|
92 |
+
heatmap = np.maximum(heatmap, 0)
|
93 |
+
if np.max(heatmap) != 0:
|
94 |
+
heatmap /= np.max(heatmap)
|
95 |
+
heatmap = np.uint8(255 * heatmap)
|
96 |
+
|
97 |
+
heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
98 |
+
original_array = np.array(original_img.convert("RGB"))
|
99 |
+
superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
|
100 |
+
|
101 |
+
return Image.fromarray(superimposed_img)
|
102 |
+
|
103 |
+
def classify_and_explain(img):
|
104 |
+
if img is None or wound_model is None:
|
105 |
+
return None, {}, "No image provided or model not available"
|
106 |
+
|
107 |
+
img_array = preprocess_input(img)
|
108 |
+
predictions = wound_model.predict(img_array, verbose=0)[0]
|
109 |
+
pred_idx = int(np.argmax(predictions))
|
110 |
+
pred_class = class_labels[pred_idx]
|
111 |
+
confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
|
112 |
+
|
113 |
+
try:
|
114 |
+
heatmap = get_gradcam_heatmap(img_array, wound_model, pred_idx)
|
115 |
+
gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
|
116 |
+
except Exception as e:
|
117 |
+
print(f"Grad-CAM error: {e}")
|
118 |
+
gradcam_img = img.resize((224, 224))
|
119 |
+
|
120 |
+
return gradcam_img, confidence_dict
|
121 |
+
|
122 |
+
def create_confidence_bars(confidence_dict):
|
123 |
+
html_content = "<div class='confidence-container'>"
|
124 |
+
for class_name, confidence in confidence_dict.items():
|
125 |
+
percentage = confidence * 100
|
126 |
+
if percentage > 70:
|
127 |
+
css_class = "confidence-high"
|
128 |
+
elif percentage > 40:
|
129 |
+
css_class = "confidence-medium"
|
130 |
+
else:
|
131 |
+
css_class = "confidence-low"
|
132 |
+
|
133 |
+
html_content += f"""
|
134 |
+
<div style='margin: 12px 0;'>
|
135 |
+
<div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
|
136 |
+
<span style='font-weight: bold;'>{class_name}</span>
|
137 |
+
<span style='font-weight: bold;'>{percentage:.1f}%</span>
|
138 |
+
</div>
|
139 |
+
<div class='confidence-bar {css_class}' style='width: {percentage}%;'></div>
|
140 |
+
</div>
|
141 |
+
"""
|
142 |
+
html_content += "</div>"
|
143 |
+
return html_content
|
144 |
+
|
145 |
+
def enhanced_classify_and_explain(img):
|
146 |
+
if img is None:
|
147 |
+
return None, "No image provided", 0, ""
|
148 |
+
|
149 |
+
gradcam_img, confidence_dict = classify_and_explain(img)
|
150 |
+
|
151 |
+
if isinstance(confidence_dict, str): # Error case
|
152 |
+
return None, confidence_dict, 0, ""
|
153 |
+
|
154 |
+
pred_class = max(confidence_dict, key=confidence_dict.get)
|
155 |
+
confidence = confidence_dict[pred_class]
|
156 |
+
confidence_bars_html = create_confidence_bars(confidence_dict)
|
157 |
+
|
158 |
+
return gradcam_img, pred_class, confidence, confidence_bars_html
|
159 |
+
|
160 |
+
# --- Depth Estimation Functions ---
|
161 |
+
def predict_depth(image):
|
162 |
+
if depth_model is None:
|
163 |
+
return None
|
164 |
+
return depth_model.infer_image(image)
|
165 |
+
|
166 |
+
def calculate_max_points(image):
|
167 |
+
if image is None:
|
168 |
+
return 10000
|
169 |
+
h, w = image.shape[:2]
|
170 |
+
max_points = h * w * 3
|
171 |
+
return max(1000, min(max_points, 1000000))
|
172 |
+
|
173 |
+
def update_slider_on_image_upload(image):
|
174 |
+
max_points = calculate_max_points(image)
|
175 |
+
default_value = min(10000, max_points // 10)
|
176 |
+
return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
|
177 |
+
label=f"Number of 3D points (max: {max_points:,})")
|
178 |
+
|
179 |
+
def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
|
180 |
+
h, w = depth_map.shape
|
181 |
+
step = max(1, int(np.sqrt(h * w / max_points)))
|
182 |
+
|
183 |
+
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
|
184 |
+
x_cam = (x_coords - w / 2) / focal_length_x
|
185 |
+
y_cam = (y_coords - h / 2) / focal_length_y
|
186 |
+
|
187 |
+
depth_values = depth_map[::step, ::step]
|
188 |
+
x_3d = x_cam * depth_values
|
189 |
+
y_3d = y_cam * depth_values
|
190 |
+
z_3d = depth_values
|
191 |
+
|
192 |
+
points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
|
193 |
+
image_colors = image[::step, ::step, :]
|
194 |
+
colors = image_colors.reshape(-1, 3) / 255.0
|
195 |
+
|
196 |
+
pcd = o3d.geometry.PointCloud()
|
197 |
+
pcd.points = o3d.utility.Vector3dVector(points)
|
198 |
+
pcd.colors = o3d.utility.Vector3dVector(colors)
|
199 |
+
|
200 |
+
return pcd
|
201 |
+
|
202 |
+
def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
|
203 |
+
h, w = depth_map.shape
|
204 |
+
step = max(1, int(np.sqrt(h * w / max_points)))
|
205 |
+
|
206 |
+
y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
|
207 |
+
focal_length = 470.4
|
208 |
+
x_cam = (x_coords - w / 2) / focal_length
|
209 |
+
y_cam = (y_coords - h / 2) / focal_length
|
210 |
+
|
211 |
+
depth_values = depth_map[::step, ::step]
|
212 |
+
x_3d = x_cam * depth_values
|
213 |
+
y_3d = y_cam * depth_values
|
214 |
+
z_3d = depth_values
|
215 |
+
|
216 |
+
x_flat = x_3d.flatten()
|
217 |
+
y_flat = y_3d.flatten()
|
218 |
+
z_flat = z_3d.flatten()
|
219 |
+
|
220 |
+
image_colors = image[::step, ::step, :]
|
221 |
+
colors_flat = image_colors.reshape(-1, 3)
|
222 |
+
|
223 |
+
fig = go.Figure(data=[go.Scatter3d(
|
224 |
+
x=x_flat,
|
225 |
+
y=y_flat,
|
226 |
+
z=z_flat,
|
227 |
+
mode='markers',
|
228 |
+
marker=dict(
|
229 |
+
size=1.5,
|
230 |
+
color=colors_flat,
|
231 |
+
opacity=0.9
|
232 |
+
),
|
233 |
+
hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
|
234 |
+
'<b>Depth:</b> %{z:.2f}<br>' +
|
235 |
+
'<extra></extra>'
|
236 |
+
)])
|
237 |
+
|
238 |
+
fig.update_layout(
|
239 |
+
title="3D Point Cloud Visualization (Camera Projection)",
|
240 |
+
scene=dict(
|
241 |
+
xaxis_title="X (meters)",
|
242 |
+
yaxis_title="Y (meters)",
|
243 |
+
zaxis_title="Z (meters)",
|
244 |
+
camera=dict(
|
245 |
+
eye=dict(x=2.0, y=2.0, z=2.0),
|
246 |
+
center=dict(x=0, y=0, z=0),
|
247 |
+
up=dict(x=0, y=0, z=1)
|
248 |
+
),
|
249 |
+
aspectmode='data'
|
250 |
+
),
|
251 |
+
width=700,
|
252 |
+
height=600
|
253 |
+
)
|
254 |
+
|
255 |
+
return fig
|
256 |
+
|
257 |
+
def on_depth_submit(image, num_points, focal_x, focal_y):
|
258 |
+
if image is None or depth_model is None:
|
259 |
+
return None, None, None, None, None
|
260 |
+
|
261 |
+
original_image = image.copy()
|
262 |
+
h, w = image.shape[:2]
|
263 |
+
depth = predict_depth(image[:, :, ::-1])
|
264 |
+
|
265 |
+
if depth is None:
|
266 |
+
return None, None, None, None, None
|
267 |
+
|
268 |
+
raw_depth = Image.fromarray(depth.astype('uint16'))
|
269 |
+
tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
270 |
+
raw_depth.save(tmp_raw_depth.name)
|
271 |
+
|
272 |
+
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
273 |
+
depth = depth.astype(np.uint8)
|
274 |
+
cmap = matplotlib.colormaps.get_cmap('Spectral_r')
|
275 |
+
colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
|
276 |
+
|
277 |
+
gray_depth = Image.fromarray(depth)
|
278 |
+
tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
279 |
+
gray_depth.save(tmp_gray_depth.name)
|
280 |
+
|
281 |
+
pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
|
282 |
+
tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
|
283 |
+
o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
|
284 |
+
|
285 |
+
depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
|
286 |
+
|
287 |
+
return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
|
288 |
+
|
289 |
+
# --- Custom CSS for Unified Interface ---
|
290 |
+
css = """
|
291 |
+
/* Minimal dark theme styling */
|
292 |
+
.main-header {
|
293 |
+
text-align: center;
|
294 |
+
margin-bottom: 2rem;
|
295 |
+
padding: 2rem 0;
|
296 |
+
}
|
297 |
+
|
298 |
+
.main-header h1 {
|
299 |
+
font-size: 2.5rem;
|
300 |
+
margin-bottom: 0.5rem;
|
301 |
+
font-weight: 600;
|
302 |
+
}
|
303 |
+
|
304 |
+
.main-header p {
|
305 |
+
font-size: 1.1rem;
|
306 |
+
opacity: 0.8;
|
307 |
+
}
|
308 |
+
|
309 |
+
.section-title {
|
310 |
+
font-size: 1.2rem;
|
311 |
+
font-weight: 600;
|
312 |
+
margin-bottom: 15px;
|
313 |
+
padding-bottom: 8px;
|
314 |
+
border-bottom: 1px solid var(--border-color-primary);
|
315 |
+
}
|
316 |
+
|
317 |
+
.confidence-container {
|
318 |
+
margin: 15px 0;
|
319 |
+
padding: 15px;
|
320 |
+
border-radius: 8px;
|
321 |
+
background: var(--background-secondary);
|
322 |
+
border: 1px solid var(--border-color-primary);
|
323 |
+
}
|
324 |
+
|
325 |
+
.confidence-bar {
|
326 |
+
height: 20px;
|
327 |
+
border-radius: 4px;
|
328 |
+
margin: 6px 0;
|
329 |
+
background: var(--primary-500);
|
330 |
+
transition: width 0.3s ease;
|
331 |
+
}
|
332 |
+
|
333 |
+
/* Simple confidence bar colors */
|
334 |
+
.confidence-high {
|
335 |
+
background: var(--success-500);
|
336 |
+
}
|
337 |
+
|
338 |
+
.confidence-medium {
|
339 |
+
background: var(--warning-500);
|
340 |
+
}
|
341 |
+
|
342 |
+
.confidence-low {
|
343 |
+
background: var(--error-500);
|
344 |
+
}
|
345 |
+
|
346 |
+
/* Minimal spacing and layout */
|
347 |
+
.gradio-container {
|
348 |
+
max-width: 100%;
|
349 |
+
margin: 0;
|
350 |
+
padding: 20px;
|
351 |
+
width: 100%;
|
352 |
+
}
|
353 |
+
|
354 |
+
/* Clean image styling */
|
355 |
+
.gradio-image {
|
356 |
+
border-radius: 8px;
|
357 |
+
border: 1px solid var(--border-color-primary);
|
358 |
+
}
|
359 |
+
|
360 |
+
/* Simple button styling */
|
361 |
+
.gradio-button {
|
362 |
+
border-radius: 6px;
|
363 |
+
font-weight: 500;
|
364 |
+
}
|
365 |
+
|
366 |
+
/* Clean form elements */
|
367 |
+
.gradio-textbox, .gradio-number, .gradio-slider {
|
368 |
+
border-radius: 6px;
|
369 |
+
border: 1px solid var(--border-color-primary);
|
370 |
+
}
|
371 |
+
|
372 |
+
/* Tab styling */
|
373 |
+
.gradio-tabs {
|
374 |
+
border-radius: 8px;
|
375 |
+
overflow: hidden;
|
376 |
+
}
|
377 |
+
|
378 |
+
/* File upload styling */
|
379 |
+
.gradio-file {
|
380 |
+
border-radius: 6px;
|
381 |
+
border: 1px solid var(--border-color-primary);
|
382 |
+
}
|
383 |
+
|
384 |
+
/* Plot styling */
|
385 |
+
.gradio-plot {
|
386 |
+
border-radius: 8px;
|
387 |
+
border: 1px solid var(--border-color-primary);
|
388 |
+
}
|
389 |
+
|
390 |
+
/* Full width and height layout */
|
391 |
+
body, html {
|
392 |
+
margin: 0;
|
393 |
+
padding: 0;
|
394 |
+
width: 100%;
|
395 |
+
height: 100%;
|
396 |
+
}
|
397 |
+
|
398 |
+
#root {
|
399 |
+
width: 100%;
|
400 |
+
height: 100%;
|
401 |
+
}
|
402 |
+
|
403 |
+
/* Ensure Gradio uses full width */
|
404 |
+
.gradio-container {
|
405 |
+
min-height: 100vh;
|
406 |
+
}
|
407 |
+
|
408 |
+
/* Responsive adjustments */
|
409 |
+
@media (max-width: 768px) {
|
410 |
+
.main-header h1 {
|
411 |
+
font-size: 2rem;
|
412 |
+
}
|
413 |
+
|
414 |
+
.gradio-container {
|
415 |
+
padding: 10px;
|
416 |
+
}
|
417 |
+
}
|
418 |
+
"""
|
419 |
+
|
420 |
+
# --- Create Unified Interface ---
|
421 |
+
with gr.Blocks(css=css, title="Medical AI Suite") as demo:
|
422 |
+
gr.HTML("""
|
423 |
+
<div class="main-header">
|
424 |
+
<h1>Medical AI Suite</h1>
|
425 |
+
<p>Advanced AI-powered medical image analysis and 3D visualization</p>
|
426 |
+
</div>
|
427 |
+
""")
|
428 |
+
|
429 |
+
with gr.Tabs() as tabs:
|
430 |
+
# Tab 1: Wound Classification
|
431 |
+
with gr.TabItem("Wound Classification", id=0):
|
432 |
+
gr.HTML("<div class='section-title'>Wound Classification with Grad-CAM Visualization</div>")
|
433 |
+
|
434 |
+
with gr.Row():
|
435 |
+
with gr.Column(scale=1):
|
436 |
+
gr.HTML("<div class='section-title'>Input Image</div>")
|
437 |
+
wound_input_image = gr.Image(
|
438 |
+
label="Upload wound image",
|
439 |
+
type="pil",
|
440 |
+
height=350,
|
441 |
+
container=True
|
442 |
+
)
|
443 |
+
|
444 |
+
with gr.Column(scale=1):
|
445 |
+
gr.HTML("<div class='section-title'>Analysis Results</div>")
|
446 |
+
wound_prediction_output = gr.Textbox(
|
447 |
+
label="Predicted Wound Type",
|
448 |
+
interactive=False,
|
449 |
+
container=True
|
450 |
+
)
|
451 |
+
wound_confidence_output = gr.Number(
|
452 |
+
label="Confidence Score",
|
453 |
+
interactive=False,
|
454 |
+
container=True
|
455 |
+
)
|
456 |
+
wound_confidence_bars = gr.HTML(
|
457 |
+
label="Confidence Scores by Class",
|
458 |
+
container=True
|
459 |
+
)
|
460 |
+
|
461 |
+
with gr.Row():
|
462 |
+
with gr.Column():
|
463 |
+
gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
|
464 |
+
wound_cam_output = gr.Image(
|
465 |
+
label="Grad-CAM Heatmap - Shows which areas the model focused on",
|
466 |
+
height=350,
|
467 |
+
container=True
|
468 |
+
)
|
469 |
+
|
470 |
+
# Event handlers for wound classification
|
471 |
+
wound_input_image.change(
|
472 |
+
fn=enhanced_classify_and_explain,
|
473 |
+
inputs=[wound_input_image],
|
474 |
+
outputs=[wound_cam_output, wound_prediction_output, wound_confidence_output, wound_confidence_bars]
|
475 |
+
)
|
476 |
+
|
477 |
+
# Tab 2: Depth Estimation
|
478 |
+
with gr.TabItem("Depth Estimation & 3D Visualization", id=1):
|
479 |
+
gr.HTML("<div class='section-title'>Depth Estimation and 3D Point Cloud Generation</div>")
|
480 |
+
|
481 |
+
with gr.Row():
|
482 |
+
depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
|
483 |
+
depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
|
484 |
+
|
485 |
+
with gr.Row():
|
486 |
+
depth_submit = gr.Button(value="Compute Depth", variant="primary")
|
487 |
+
depth_points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
|
488 |
+
label="Number of 3D points (upload image to update max)")
|
489 |
+
|
490 |
+
with gr.Row():
|
491 |
+
depth_focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
492 |
+
label="Focal Length X (pixels)")
|
493 |
+
depth_focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
|
494 |
+
label="Focal Length Y (pixels)")
|
495 |
+
|
496 |
+
with gr.Row():
|
497 |
+
depth_gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
|
498 |
+
depth_raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
|
499 |
+
depth_point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
|
500 |
+
|
501 |
+
gr.Markdown("### 3D Point Cloud Visualization")
|
502 |
+
gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
|
503 |
+
depth_3d_plot = gr.Plot(label="3D Point Cloud")
|
504 |
+
|
505 |
+
# Event handlers for depth estimation
|
506 |
+
depth_input_image.change(
|
507 |
+
fn=update_slider_on_image_upload,
|
508 |
+
inputs=[depth_input_image],
|
509 |
+
outputs=[depth_points_slider]
|
510 |
+
)
|
511 |
+
|
512 |
+
depth_submit.click(
|
513 |
+
on_depth_submit,
|
514 |
+
inputs=[depth_input_image, depth_points_slider, depth_focal_length_x, depth_focal_length_y],
|
515 |
+
outputs=[depth_image_slider, depth_gray_depth_file, depth_raw_file, depth_point_cloud_file, depth_3d_plot]
|
516 |
+
)
|
517 |
+
|
518 |
+
# Cross-tab image sharing functionality
|
519 |
+
# When image is uploaded in wound classification, also update depth estimation
|
520 |
+
wound_input_image.change(
|
521 |
+
fn=lambda img: img,
|
522 |
+
inputs=[wound_input_image],
|
523 |
+
outputs=[depth_input_image]
|
524 |
+
)
|
525 |
+
|
526 |
+
# When image is uploaded in depth estimation, also update wound classification
|
527 |
+
depth_input_image.change(
|
528 |
+
fn=lambda img: img,
|
529 |
+
inputs=[depth_input_image],
|
530 |
+
outputs=[wound_input_image]
|
531 |
+
)
|
532 |
+
|
533 |
+
# --- Launch the unified interface ---
|
534 |
+
if __name__ == "__main__":
|
535 |
+
demo.queue().launch(
|
536 |
+
server_name="0.0.0.0",
|
537 |
+
server_port=7860,
|
538 |
+
share=True,
|
539 |
+
show_error=True
|
540 |
+
)
|
metric_depth/README.md
ADDED
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Depth Anything V2 for Metric Depth Estimation
|
2 |
+
|
3 |
+

|
4 |
+
|
5 |
+
We here provide a simple codebase to fine-tune our Depth Anything V2 pre-trained encoder for metric depth estimation. Built on our powerful encoder, we use a simple DPT head to regress the depth. We fine-tune our pre-trained encoder on synthetic Hypersim / Virtual KITTI datasets for indoor / outdoor metric depth estimation, respectively.
|
6 |
+
|
7 |
+
|
8 |
+
# Pre-trained Models
|
9 |
+
|
10 |
+
We provide **six metric depth models** of three scales for indoor and outdoor scenes, respectively.
|
11 |
+
|
12 |
+
| Base Model | Params | Indoor (Hypersim) | Outdoor (Virtual KITTI 2) |
|
13 |
+
|:-|-:|:-:|:-:|
|
14 |
+
| Depth-Anything-V2-Small | 24.8M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Small/resolve/main/depth_anything_v2_metric_hypersim_vits.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Small/resolve/main/depth_anything_v2_metric_vkitti_vits.pth?download=true) |
|
15 |
+
| Depth-Anything-V2-Base | 97.5M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Base/resolve/main/depth_anything_v2_metric_hypersim_vitb.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Base/resolve/main/depth_anything_v2_metric_vkitti_vitb.pth?download=true) |
|
16 |
+
| Depth-Anything-V2-Large | 335.3M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Large/resolve/main/depth_anything_v2_metric_hypersim_vitl.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Large/resolve/main/depth_anything_v2_metric_vkitti_vitl.pth?download=true) |
|
17 |
+
|
18 |
+
*We recommend to first try our larger models (if computational cost is affordable) and the indoor version.*
|
19 |
+
|
20 |
+
## Usage
|
21 |
+
|
22 |
+
### Prepraration
|
23 |
+
|
24 |
+
```bash
|
25 |
+
git clone https://github.com/DepthAnything/Depth-Anything-V2
|
26 |
+
cd Depth-Anything-V2/metric_depth
|
27 |
+
pip install -r requirements.txt
|
28 |
+
```
|
29 |
+
|
30 |
+
Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory.
|
31 |
+
|
32 |
+
### Use our models
|
33 |
+
```python
|
34 |
+
import cv2
|
35 |
+
import torch
|
36 |
+
|
37 |
+
from depth_anything_v2.dpt import DepthAnythingV2
|
38 |
+
|
39 |
+
model_configs = {
|
40 |
+
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
41 |
+
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
|
42 |
+
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
|
43 |
+
}
|
44 |
+
|
45 |
+
encoder = 'vitl' # or 'vits', 'vitb'
|
46 |
+
dataset = 'hypersim' # 'hypersim' for indoor model, 'vkitti' for outdoor model
|
47 |
+
max_depth = 20 # 20 for indoor model, 80 for outdoor model
|
48 |
+
|
49 |
+
model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
|
50 |
+
model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_metric_{dataset}_{encoder}.pth', map_location='cpu'))
|
51 |
+
model.eval()
|
52 |
+
|
53 |
+
raw_img = cv2.imread('your/image/path')
|
54 |
+
depth = model.infer_image(raw_img) # HxW depth map in meters in numpy
|
55 |
+
```
|
56 |
+
|
57 |
+
### Running script on images
|
58 |
+
|
59 |
+
Here, we take the `vitl` encoder as an example. You can also use `vitb` or `vits` encoders.
|
60 |
+
|
61 |
+
```bash
|
62 |
+
# indoor scenes
|
63 |
+
python run.py \
|
64 |
+
--encoder vitl \
|
65 |
+
--load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
|
66 |
+
--max-depth 20 \
|
67 |
+
--img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
|
68 |
+
|
69 |
+
# outdoor scenes
|
70 |
+
python run.py \
|
71 |
+
--encoder vitl \
|
72 |
+
--load-from checkpoints/depth_anything_v2_metric_vkitti_vitl.pth \
|
73 |
+
--max-depth 80 \
|
74 |
+
--img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
|
75 |
+
```
|
76 |
+
|
77 |
+
### Project 2D images to point clouds:
|
78 |
+
|
79 |
+
```bash
|
80 |
+
python depth_to_pointcloud.py \
|
81 |
+
--encoder vitl \
|
82 |
+
--load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
|
83 |
+
--max-depth 20 \
|
84 |
+
--img-path <path> --outdir <outdir>
|
85 |
+
```
|
86 |
+
|
87 |
+
### Reproduce training
|
88 |
+
|
89 |
+
Please first prepare the [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets. Then:
|
90 |
+
|
91 |
+
```bash
|
92 |
+
bash dist_train.sh
|
93 |
+
```
|
94 |
+
|
95 |
+
|
96 |
+
## Citation
|
97 |
+
|
98 |
+
If you find this project useful, please consider citing:
|
99 |
+
|
100 |
+
```bibtex
|
101 |
+
@article{depth_anything_v2,
|
102 |
+
title={Depth Anything V2},
|
103 |
+
author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Zhao, Zhen and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
|
104 |
+
journal={arXiv:2406.09414},
|
105 |
+
year={2024}
|
106 |
+
}
|
107 |
+
|
108 |
+
@inproceedings{depth_anything_v1,
|
109 |
+
title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
|
110 |
+
author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
|
111 |
+
booktitle={CVPR},
|
112 |
+
year={2024}
|
113 |
+
}
|
114 |
+
```
|
metric_depth/assets/compare_zoedepth.png
ADDED
![]() |
Git LFS Details
|
metric_depth/dataset/hypersim.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import h5py
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from torchvision.transforms import Compose
|
7 |
+
|
8 |
+
from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
|
9 |
+
|
10 |
+
|
11 |
+
def hypersim_distance_to_depth(npyDistance):
|
12 |
+
intWidth, intHeight, fltFocal = 1024, 768, 886.81
|
13 |
+
|
14 |
+
npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
|
15 |
+
1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
|
16 |
+
npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
|
17 |
+
intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
|
18 |
+
npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
|
19 |
+
npyImageplane = np.concatenate(
|
20 |
+
[npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)
|
21 |
+
|
22 |
+
npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
|
23 |
+
return npyDepth
|
24 |
+
|
25 |
+
|
26 |
+
class Hypersim(Dataset):
|
27 |
+
def __init__(self, filelist_path, mode, size=(518, 518)):
|
28 |
+
|
29 |
+
self.mode = mode
|
30 |
+
self.size = size
|
31 |
+
|
32 |
+
with open(filelist_path, 'r') as f:
|
33 |
+
self.filelist = f.read().splitlines()
|
34 |
+
|
35 |
+
net_w, net_h = size
|
36 |
+
self.transform = Compose([
|
37 |
+
Resize(
|
38 |
+
width=net_w,
|
39 |
+
height=net_h,
|
40 |
+
resize_target=True if mode == 'train' else False,
|
41 |
+
keep_aspect_ratio=True,
|
42 |
+
ensure_multiple_of=14,
|
43 |
+
resize_method='lower_bound',
|
44 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
45 |
+
),
|
46 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
47 |
+
PrepareForNet(),
|
48 |
+
] + ([Crop(size[0])] if self.mode == 'train' else []))
|
49 |
+
|
50 |
+
def __getitem__(self, item):
|
51 |
+
img_path = self.filelist[item].split(' ')[0]
|
52 |
+
depth_path = self.filelist[item].split(' ')[1]
|
53 |
+
|
54 |
+
image = cv2.imread(img_path)
|
55 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
|
56 |
+
|
57 |
+
depth_fd = h5py.File(depth_path, "r")
|
58 |
+
distance_meters = np.array(depth_fd['dataset'])
|
59 |
+
depth = hypersim_distance_to_depth(distance_meters)
|
60 |
+
|
61 |
+
sample = self.transform({'image': image, 'depth': depth})
|
62 |
+
|
63 |
+
sample['image'] = torch.from_numpy(sample['image'])
|
64 |
+
sample['depth'] = torch.from_numpy(sample['depth'])
|
65 |
+
|
66 |
+
sample['valid_mask'] = (torch.isnan(sample['depth']) == 0)
|
67 |
+
sample['depth'][sample['valid_mask'] == 0] = 0
|
68 |
+
|
69 |
+
sample['image_path'] = self.filelist[item].split(' ')[0]
|
70 |
+
|
71 |
+
return sample
|
72 |
+
|
73 |
+
def __len__(self):
|
74 |
+
return len(self.filelist)
|
metric_depth/dataset/kitti.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from torchvision.transforms import Compose
|
5 |
+
|
6 |
+
from dataset.transform import Resize, NormalizeImage, PrepareForNet
|
7 |
+
|
8 |
+
|
9 |
+
class KITTI(Dataset):
|
10 |
+
def __init__(self, filelist_path, mode, size=(518, 518)):
|
11 |
+
if mode != 'val':
|
12 |
+
raise NotImplementedError
|
13 |
+
|
14 |
+
self.mode = mode
|
15 |
+
self.size = size
|
16 |
+
|
17 |
+
with open(filelist_path, 'r') as f:
|
18 |
+
self.filelist = f.read().splitlines()
|
19 |
+
|
20 |
+
net_w, net_h = size
|
21 |
+
self.transform = Compose([
|
22 |
+
Resize(
|
23 |
+
width=net_w,
|
24 |
+
height=net_h,
|
25 |
+
resize_target=True if mode == 'train' else False,
|
26 |
+
keep_aspect_ratio=True,
|
27 |
+
ensure_multiple_of=14,
|
28 |
+
resize_method='lower_bound',
|
29 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
30 |
+
),
|
31 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
32 |
+
PrepareForNet(),
|
33 |
+
])
|
34 |
+
|
35 |
+
def __getitem__(self, item):
|
36 |
+
img_path = self.filelist[item].split(' ')[0]
|
37 |
+
depth_path = self.filelist[item].split(' ')[1]
|
38 |
+
|
39 |
+
image = cv2.imread(img_path)
|
40 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
|
41 |
+
|
42 |
+
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype('float32')
|
43 |
+
|
44 |
+
sample = self.transform({'image': image, 'depth': depth})
|
45 |
+
|
46 |
+
sample['image'] = torch.from_numpy(sample['image'])
|
47 |
+
sample['depth'] = torch.from_numpy(sample['depth'])
|
48 |
+
sample['depth'] = sample['depth'] / 256.0 # convert in meters
|
49 |
+
|
50 |
+
sample['valid_mask'] = sample['depth'] > 0
|
51 |
+
|
52 |
+
sample['image_path'] = self.filelist[item].split(' ')[0]
|
53 |
+
|
54 |
+
return sample
|
55 |
+
|
56 |
+
def __len__(self):
|
57 |
+
return len(self.filelist)
|
metric_depth/dataset/splits/hypersim/val.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
metric_depth/dataset/splits/kitti/val.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
metric_depth/dataset/splits/vkitti2/train.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
metric_depth/dataset/transform.py
ADDED
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import math
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
|
8 |
+
def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
|
9 |
+
"""Rezise the sample to ensure the given size. Keeps aspect ratio.
|
10 |
+
|
11 |
+
Args:
|
12 |
+
sample (dict): sample
|
13 |
+
size (tuple): image size
|
14 |
+
|
15 |
+
Returns:
|
16 |
+
tuple: new size
|
17 |
+
"""
|
18 |
+
shape = list(sample["disparity"].shape)
|
19 |
+
|
20 |
+
if shape[0] >= size[0] and shape[1] >= size[1]:
|
21 |
+
return sample
|
22 |
+
|
23 |
+
scale = [0, 0]
|
24 |
+
scale[0] = size[0] / shape[0]
|
25 |
+
scale[1] = size[1] / shape[1]
|
26 |
+
|
27 |
+
scale = max(scale)
|
28 |
+
|
29 |
+
shape[0] = math.ceil(scale * shape[0])
|
30 |
+
shape[1] = math.ceil(scale * shape[1])
|
31 |
+
|
32 |
+
# resize
|
33 |
+
sample["image"] = cv2.resize(
|
34 |
+
sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
|
35 |
+
)
|
36 |
+
|
37 |
+
sample["disparity"] = cv2.resize(
|
38 |
+
sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
|
39 |
+
)
|
40 |
+
sample["mask"] = cv2.resize(
|
41 |
+
sample["mask"].astype(np.float32),
|
42 |
+
tuple(shape[::-1]),
|
43 |
+
interpolation=cv2.INTER_NEAREST,
|
44 |
+
)
|
45 |
+
sample["mask"] = sample["mask"].astype(bool)
|
46 |
+
|
47 |
+
return tuple(shape)
|
48 |
+
|
49 |
+
|
50 |
+
class Resize(object):
|
51 |
+
"""Resize sample to given size (width, height).
|
52 |
+
"""
|
53 |
+
|
54 |
+
def __init__(
|
55 |
+
self,
|
56 |
+
width,
|
57 |
+
height,
|
58 |
+
resize_target=True,
|
59 |
+
keep_aspect_ratio=False,
|
60 |
+
ensure_multiple_of=1,
|
61 |
+
resize_method="lower_bound",
|
62 |
+
image_interpolation_method=cv2.INTER_AREA,
|
63 |
+
):
|
64 |
+
"""Init.
|
65 |
+
|
66 |
+
Args:
|
67 |
+
width (int): desired output width
|
68 |
+
height (int): desired output height
|
69 |
+
resize_target (bool, optional):
|
70 |
+
True: Resize the full sample (image, mask, target).
|
71 |
+
False: Resize image only.
|
72 |
+
Defaults to True.
|
73 |
+
keep_aspect_ratio (bool, optional):
|
74 |
+
True: Keep the aspect ratio of the input sample.
|
75 |
+
Output sample might not have the given width and height, and
|
76 |
+
resize behaviour depends on the parameter 'resize_method'.
|
77 |
+
Defaults to False.
|
78 |
+
ensure_multiple_of (int, optional):
|
79 |
+
Output width and height is constrained to be multiple of this parameter.
|
80 |
+
Defaults to 1.
|
81 |
+
resize_method (str, optional):
|
82 |
+
"lower_bound": Output will be at least as large as the given size.
|
83 |
+
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
|
84 |
+
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
|
85 |
+
Defaults to "lower_bound".
|
86 |
+
"""
|
87 |
+
self.__width = width
|
88 |
+
self.__height = height
|
89 |
+
|
90 |
+
self.__resize_target = resize_target
|
91 |
+
self.__keep_aspect_ratio = keep_aspect_ratio
|
92 |
+
self.__multiple_of = ensure_multiple_of
|
93 |
+
self.__resize_method = resize_method
|
94 |
+
self.__image_interpolation_method = image_interpolation_method
|
95 |
+
|
96 |
+
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
|
97 |
+
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
98 |
+
|
99 |
+
if max_val is not None and y > max_val:
|
100 |
+
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
101 |
+
|
102 |
+
if y < min_val:
|
103 |
+
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
|
104 |
+
|
105 |
+
return y
|
106 |
+
|
107 |
+
def get_size(self, width, height):
|
108 |
+
# determine new height and width
|
109 |
+
scale_height = self.__height / height
|
110 |
+
scale_width = self.__width / width
|
111 |
+
|
112 |
+
if self.__keep_aspect_ratio:
|
113 |
+
if self.__resize_method == "lower_bound":
|
114 |
+
# scale such that output size is lower bound
|
115 |
+
if scale_width > scale_height:
|
116 |
+
# fit width
|
117 |
+
scale_height = scale_width
|
118 |
+
else:
|
119 |
+
# fit height
|
120 |
+
scale_width = scale_height
|
121 |
+
elif self.__resize_method == "upper_bound":
|
122 |
+
# scale such that output size is upper bound
|
123 |
+
if scale_width < scale_height:
|
124 |
+
# fit width
|
125 |
+
scale_height = scale_width
|
126 |
+
else:
|
127 |
+
# fit height
|
128 |
+
scale_width = scale_height
|
129 |
+
elif self.__resize_method == "minimal":
|
130 |
+
# scale as least as possbile
|
131 |
+
if abs(1 - scale_width) < abs(1 - scale_height):
|
132 |
+
# fit width
|
133 |
+
scale_height = scale_width
|
134 |
+
else:
|
135 |
+
# fit height
|
136 |
+
scale_width = scale_height
|
137 |
+
else:
|
138 |
+
raise ValueError(
|
139 |
+
f"resize_method {self.__resize_method} not implemented"
|
140 |
+
)
|
141 |
+
|
142 |
+
if self.__resize_method == "lower_bound":
|
143 |
+
new_height = self.constrain_to_multiple_of(
|
144 |
+
scale_height * height, min_val=self.__height
|
145 |
+
)
|
146 |
+
new_width = self.constrain_to_multiple_of(
|
147 |
+
scale_width * width, min_val=self.__width
|
148 |
+
)
|
149 |
+
elif self.__resize_method == "upper_bound":
|
150 |
+
new_height = self.constrain_to_multiple_of(
|
151 |
+
scale_height * height, max_val=self.__height
|
152 |
+
)
|
153 |
+
new_width = self.constrain_to_multiple_of(
|
154 |
+
scale_width * width, max_val=self.__width
|
155 |
+
)
|
156 |
+
elif self.__resize_method == "minimal":
|
157 |
+
new_height = self.constrain_to_multiple_of(scale_height * height)
|
158 |
+
new_width = self.constrain_to_multiple_of(scale_width * width)
|
159 |
+
else:
|
160 |
+
raise ValueError(f"resize_method {self.__resize_method} not implemented")
|
161 |
+
|
162 |
+
return (new_width, new_height)
|
163 |
+
|
164 |
+
def __call__(self, sample):
|
165 |
+
width, height = self.get_size(
|
166 |
+
sample["image"].shape[1], sample["image"].shape[0]
|
167 |
+
)
|
168 |
+
|
169 |
+
# resize sample
|
170 |
+
sample["image"] = cv2.resize(
|
171 |
+
sample["image"],
|
172 |
+
(width, height),
|
173 |
+
interpolation=self.__image_interpolation_method,
|
174 |
+
)
|
175 |
+
|
176 |
+
if self.__resize_target:
|
177 |
+
if "disparity" in sample:
|
178 |
+
sample["disparity"] = cv2.resize(
|
179 |
+
sample["disparity"],
|
180 |
+
(width, height),
|
181 |
+
interpolation=cv2.INTER_NEAREST,
|
182 |
+
)
|
183 |
+
|
184 |
+
if "depth" in sample:
|
185 |
+
sample["depth"] = cv2.resize(
|
186 |
+
sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
|
187 |
+
)
|
188 |
+
|
189 |
+
if "semseg_mask" in sample:
|
190 |
+
# sample["semseg_mask"] = cv2.resize(
|
191 |
+
# sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
|
192 |
+
# )
|
193 |
+
sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
|
194 |
+
|
195 |
+
if "mask" in sample:
|
196 |
+
sample["mask"] = cv2.resize(
|
197 |
+
sample["mask"].astype(np.float32),
|
198 |
+
(width, height),
|
199 |
+
interpolation=cv2.INTER_NEAREST,
|
200 |
+
)
|
201 |
+
# sample["mask"] = sample["mask"].astype(bool)
|
202 |
+
|
203 |
+
# print(sample['image'].shape, sample['depth'].shape)
|
204 |
+
return sample
|
205 |
+
|
206 |
+
|
207 |
+
class NormalizeImage(object):
|
208 |
+
"""Normlize image by given mean and std.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, mean, std):
|
212 |
+
self.__mean = mean
|
213 |
+
self.__std = std
|
214 |
+
|
215 |
+
def __call__(self, sample):
|
216 |
+
sample["image"] = (sample["image"] - self.__mean) / self.__std
|
217 |
+
|
218 |
+
return sample
|
219 |
+
|
220 |
+
|
221 |
+
class PrepareForNet(object):
|
222 |
+
"""Prepare sample for usage as network input.
|
223 |
+
"""
|
224 |
+
|
225 |
+
def __init__(self):
|
226 |
+
pass
|
227 |
+
|
228 |
+
def __call__(self, sample):
|
229 |
+
image = np.transpose(sample["image"], (2, 0, 1))
|
230 |
+
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
|
231 |
+
|
232 |
+
if "mask" in sample:
|
233 |
+
sample["mask"] = sample["mask"].astype(np.float32)
|
234 |
+
sample["mask"] = np.ascontiguousarray(sample["mask"])
|
235 |
+
|
236 |
+
if "depth" in sample:
|
237 |
+
depth = sample["depth"].astype(np.float32)
|
238 |
+
sample["depth"] = np.ascontiguousarray(depth)
|
239 |
+
|
240 |
+
if "semseg_mask" in sample:
|
241 |
+
sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
|
242 |
+
sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
|
243 |
+
|
244 |
+
return sample
|
245 |
+
|
246 |
+
|
247 |
+
class Crop(object):
|
248 |
+
"""Crop sample for batch-wise training. Image is of shape CxHxW
|
249 |
+
"""
|
250 |
+
|
251 |
+
def __init__(self, size):
|
252 |
+
if isinstance(size, int):
|
253 |
+
self.size = (size, size)
|
254 |
+
else:
|
255 |
+
self.size = size
|
256 |
+
|
257 |
+
def __call__(self, sample):
|
258 |
+
h, w = sample['image'].shape[-2:]
|
259 |
+
assert h >= self.size[0] and w >= self.size[1], 'Wrong size'
|
260 |
+
|
261 |
+
h_start = np.random.randint(0, h - self.size[0] + 1)
|
262 |
+
w_start = np.random.randint(0, w - self.size[1] + 1)
|
263 |
+
h_end = h_start + self.size[0]
|
264 |
+
w_end = w_start + self.size[1]
|
265 |
+
|
266 |
+
sample['image'] = sample['image'][:, h_start: h_end, w_start: w_end]
|
267 |
+
|
268 |
+
if "depth" in sample:
|
269 |
+
sample["depth"] = sample["depth"][h_start: h_end, w_start: w_end]
|
270 |
+
|
271 |
+
if "mask" in sample:
|
272 |
+
sample["mask"] = sample["mask"][h_start: h_end, w_start: w_end]
|
273 |
+
|
274 |
+
if "semseg_mask" in sample:
|
275 |
+
sample["semseg_mask"] = sample["semseg_mask"][h_start: h_end, w_start: w_end]
|
276 |
+
|
277 |
+
return sample
|
metric_depth/dataset/vkitti2.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
from torchvision.transforms import Compose
|
5 |
+
|
6 |
+
from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
|
7 |
+
|
8 |
+
|
9 |
+
class VKITTI2(Dataset):
|
10 |
+
def __init__(self, filelist_path, mode, size=(518, 518)):
|
11 |
+
|
12 |
+
self.mode = mode
|
13 |
+
self.size = size
|
14 |
+
|
15 |
+
with open(filelist_path, 'r') as f:
|
16 |
+
self.filelist = f.read().splitlines()
|
17 |
+
|
18 |
+
net_w, net_h = size
|
19 |
+
self.transform = Compose([
|
20 |
+
Resize(
|
21 |
+
width=net_w,
|
22 |
+
height=net_h,
|
23 |
+
resize_target=True if mode == 'train' else False,
|
24 |
+
keep_aspect_ratio=True,
|
25 |
+
ensure_multiple_of=14,
|
26 |
+
resize_method='lower_bound',
|
27 |
+
image_interpolation_method=cv2.INTER_CUBIC,
|
28 |
+
),
|
29 |
+
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
|
30 |
+
PrepareForNet(),
|
31 |
+
] + ([Crop(size[0])] if self.mode == 'train' else []))
|
32 |
+
|
33 |
+
def __getitem__(self, item):
|
34 |
+
img_path = self.filelist[item].split(' ')[0]
|
35 |
+
depth_path = self.filelist[item].split(' ')[1]
|
36 |
+
|
37 |
+
image = cv2.imread(img_path)
|
38 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
|
39 |
+
|
40 |
+
depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
|
41 |
+
|
42 |
+
sample = self.transform({'image': image, 'depth': depth})
|
43 |
+
|
44 |
+
sample['image'] = torch.from_numpy(sample['image'])
|
45 |
+
sample['depth'] = torch.from_numpy(sample['depth'])
|
46 |
+
|
47 |
+
sample['valid_mask'] = (sample['depth'] <= 80)
|
48 |
+
|
49 |
+
sample['image_path'] = self.filelist[item].split(' ')[0]
|
50 |
+
|
51 |
+
return sample
|
52 |
+
|
53 |
+
def __len__(self):
|
54 |
+
return len(self.filelist)
|
metric_depth/depth_anything_v2/dinov2.py
ADDED
@@ -0,0 +1,415 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
# we add a small number to avoid floating point error in the interpolation
|
192 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
193 |
+
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
|
194 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
195 |
+
# w0, h0 = w0 + 0.1, h0 + 0.1
|
196 |
+
|
197 |
+
sqrt_N = math.sqrt(N)
|
198 |
+
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
|
199 |
+
patch_pos_embed = nn.functional.interpolate(
|
200 |
+
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
|
201 |
+
scale_factor=(sx, sy),
|
202 |
+
# (int(w0), int(h0)), # to solve the upsampling shape issue
|
203 |
+
mode="bicubic",
|
204 |
+
antialias=self.interpolate_antialias
|
205 |
+
)
|
206 |
+
|
207 |
+
assert int(w0) == patch_pos_embed.shape[-2]
|
208 |
+
assert int(h0) == patch_pos_embed.shape[-1]
|
209 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
210 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
211 |
+
|
212 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
213 |
+
B, nc, w, h = x.shape
|
214 |
+
x = self.patch_embed(x)
|
215 |
+
if masks is not None:
|
216 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
217 |
+
|
218 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
219 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
220 |
+
|
221 |
+
if self.register_tokens is not None:
|
222 |
+
x = torch.cat(
|
223 |
+
(
|
224 |
+
x[:, :1],
|
225 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
226 |
+
x[:, 1:],
|
227 |
+
),
|
228 |
+
dim=1,
|
229 |
+
)
|
230 |
+
|
231 |
+
return x
|
232 |
+
|
233 |
+
def forward_features_list(self, x_list, masks_list):
|
234 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
235 |
+
for blk in self.blocks:
|
236 |
+
x = blk(x)
|
237 |
+
|
238 |
+
all_x = x
|
239 |
+
output = []
|
240 |
+
for x, masks in zip(all_x, masks_list):
|
241 |
+
x_norm = self.norm(x)
|
242 |
+
output.append(
|
243 |
+
{
|
244 |
+
"x_norm_clstoken": x_norm[:, 0],
|
245 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
246 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
247 |
+
"x_prenorm": x,
|
248 |
+
"masks": masks,
|
249 |
+
}
|
250 |
+
)
|
251 |
+
return output
|
252 |
+
|
253 |
+
def forward_features(self, x, masks=None):
|
254 |
+
if isinstance(x, list):
|
255 |
+
return self.forward_features_list(x, masks)
|
256 |
+
|
257 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
258 |
+
|
259 |
+
for blk in self.blocks:
|
260 |
+
x = blk(x)
|
261 |
+
|
262 |
+
x_norm = self.norm(x)
|
263 |
+
return {
|
264 |
+
"x_norm_clstoken": x_norm[:, 0],
|
265 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
266 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
267 |
+
"x_prenorm": x,
|
268 |
+
"masks": masks,
|
269 |
+
}
|
270 |
+
|
271 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
272 |
+
x = self.prepare_tokens_with_masks(x)
|
273 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
274 |
+
output, total_block_len = [], len(self.blocks)
|
275 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
276 |
+
for i, blk in enumerate(self.blocks):
|
277 |
+
x = blk(x)
|
278 |
+
if i in blocks_to_take:
|
279 |
+
output.append(x)
|
280 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
281 |
+
return output
|
282 |
+
|
283 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
284 |
+
x = self.prepare_tokens_with_masks(x)
|
285 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
286 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
287 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
288 |
+
for block_chunk in self.blocks:
|
289 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
290 |
+
x = blk(x)
|
291 |
+
if i in blocks_to_take:
|
292 |
+
output.append(x)
|
293 |
+
i += 1
|
294 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
295 |
+
return output
|
296 |
+
|
297 |
+
def get_intermediate_layers(
|
298 |
+
self,
|
299 |
+
x: torch.Tensor,
|
300 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
301 |
+
reshape: bool = False,
|
302 |
+
return_class_token: bool = False,
|
303 |
+
norm=True
|
304 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
305 |
+
if self.chunked_blocks:
|
306 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
307 |
+
else:
|
308 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
309 |
+
if norm:
|
310 |
+
outputs = [self.norm(out) for out in outputs]
|
311 |
+
class_tokens = [out[:, 0] for out in outputs]
|
312 |
+
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
|
313 |
+
if reshape:
|
314 |
+
B, _, w, h = x.shape
|
315 |
+
outputs = [
|
316 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
317 |
+
for out in outputs
|
318 |
+
]
|
319 |
+
if return_class_token:
|
320 |
+
return tuple(zip(outputs, class_tokens))
|
321 |
+
return tuple(outputs)
|
322 |
+
|
323 |
+
def forward(self, *args, is_training=False, **kwargs):
|
324 |
+
ret = self.forward_features(*args, **kwargs)
|
325 |
+
if is_training:
|
326 |
+
return ret
|
327 |
+
else:
|
328 |
+
return self.head(ret["x_norm_clstoken"])
|
329 |
+
|
330 |
+
|
331 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
332 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
333 |
+
if isinstance(module, nn.Linear):
|
334 |
+
trunc_normal_(module.weight, std=0.02)
|
335 |
+
if module.bias is not None:
|
336 |
+
nn.init.zeros_(module.bias)
|
337 |
+
|
338 |
+
|
339 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
340 |
+
model = DinoVisionTransformer(
|
341 |
+
patch_size=patch_size,
|
342 |
+
embed_dim=384,
|
343 |
+
depth=12,
|
344 |
+
num_heads=6,
|
345 |
+
mlp_ratio=4,
|
346 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
347 |
+
num_register_tokens=num_register_tokens,
|
348 |
+
**kwargs,
|
349 |
+
)
|
350 |
+
return model
|
351 |
+
|
352 |
+
|
353 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
354 |
+
model = DinoVisionTransformer(
|
355 |
+
patch_size=patch_size,
|
356 |
+
embed_dim=768,
|
357 |
+
depth=12,
|
358 |
+
num_heads=12,
|
359 |
+
mlp_ratio=4,
|
360 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
361 |
+
num_register_tokens=num_register_tokens,
|
362 |
+
**kwargs,
|
363 |
+
)
|
364 |
+
return model
|
365 |
+
|
366 |
+
|
367 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
368 |
+
model = DinoVisionTransformer(
|
369 |
+
patch_size=patch_size,
|
370 |
+
embed_dim=1024,
|
371 |
+
depth=24,
|
372 |
+
num_heads=16,
|
373 |
+
mlp_ratio=4,
|
374 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
375 |
+
num_register_tokens=num_register_tokens,
|
376 |
+
**kwargs,
|
377 |
+
)
|
378 |
+
return model
|
379 |
+
|
380 |
+
|
381 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
382 |
+
"""
|
383 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
384 |
+
"""
|
385 |
+
model = DinoVisionTransformer(
|
386 |
+
patch_size=patch_size,
|
387 |
+
embed_dim=1536,
|
388 |
+
depth=40,
|
389 |
+
num_heads=24,
|
390 |
+
mlp_ratio=4,
|
391 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
392 |
+
num_register_tokens=num_register_tokens,
|
393 |
+
**kwargs,
|
394 |
+
)
|
395 |
+
return model
|
396 |
+
|
397 |
+
|
398 |
+
def DINOv2(model_name):
|
399 |
+
model_zoo = {
|
400 |
+
"vits": vit_small,
|
401 |
+
"vitb": vit_base,
|
402 |
+
"vitl": vit_large,
|
403 |
+
"vitg": vit_giant2
|
404 |
+
}
|
405 |
+
|
406 |
+
return model_zoo[model_name](
|
407 |
+
img_size=518,
|
408 |
+
patch_size=14,
|
409 |
+
init_values=1.0,
|
410 |
+
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
|
411 |
+
block_chunks=0,
|
412 |
+
num_register_tokens=0,
|
413 |
+
interpolate_antialias=False,
|
414 |
+
interpolate_offset=0.1
|
415 |
+
)
|
metric_depth/depth_anything_v2/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
metric_depth/depth_anything_v2/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
|
13 |
+
from torch import Tensor
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
try:
|
21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
22 |
+
|
23 |
+
XFORMERS_AVAILABLE = True
|
24 |
+
except ImportError:
|
25 |
+
logger.warning("xFormers not available")
|
26 |
+
XFORMERS_AVAILABLE = False
|
27 |
+
|
28 |
+
|
29 |
+
class Attention(nn.Module):
|
30 |
+
def __init__(
|
31 |
+
self,
|
32 |
+
dim: int,
|
33 |
+
num_heads: int = 8,
|
34 |
+
qkv_bias: bool = False,
|
35 |
+
proj_bias: bool = True,
|
36 |
+
attn_drop: float = 0.0,
|
37 |
+
proj_drop: float = 0.0,
|
38 |
+
) -> None:
|
39 |
+
super().__init__()
|
40 |
+
self.num_heads = num_heads
|
41 |
+
head_dim = dim // num_heads
|
42 |
+
self.scale = head_dim**-0.5
|
43 |
+
|
44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
48 |
+
|
49 |
+
def forward(self, x: Tensor) -> Tensor:
|
50 |
+
B, N, C = x.shape
|
51 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
52 |
+
|
53 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
54 |
+
attn = q @ k.transpose(-2, -1)
|
55 |
+
|
56 |
+
attn = attn.softmax(dim=-1)
|
57 |
+
attn = self.attn_drop(attn)
|
58 |
+
|
59 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
60 |
+
x = self.proj(x)
|
61 |
+
x = self.proj_drop(x)
|
62 |
+
return x
|
63 |
+
|
64 |
+
|
65 |
+
class MemEffAttention(Attention):
|
66 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
67 |
+
if not XFORMERS_AVAILABLE:
|
68 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
69 |
+
return super().forward(x)
|
70 |
+
|
71 |
+
B, N, C = x.shape
|
72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
73 |
+
|
74 |
+
q, k, v = unbind(qkv, 2)
|
75 |
+
|
76 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
77 |
+
x = x.reshape([B, N, C])
|
78 |
+
|
79 |
+
x = self.proj(x)
|
80 |
+
x = self.proj_drop(x)
|
81 |
+
return x
|
82 |
+
|
83 |
+
|
metric_depth/depth_anything_v2/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
# References:
|
8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
10 |
+
|
11 |
+
import logging
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
|
14 |
+
import torch
|
15 |
+
from torch import nn, Tensor
|
16 |
+
|
17 |
+
from .attention import Attention, MemEffAttention
|
18 |
+
from .drop_path import DropPath
|
19 |
+
from .layer_scale import LayerScale
|
20 |
+
from .mlp import Mlp
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
try:
|
27 |
+
from xformers.ops import fmha
|
28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
29 |
+
|
30 |
+
XFORMERS_AVAILABLE = True
|
31 |
+
except ImportError:
|
32 |
+
logger.warning("xFormers not available")
|
33 |
+
XFORMERS_AVAILABLE = False
|
34 |
+
|
35 |
+
|
36 |
+
class Block(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int,
|
41 |
+
mlp_ratio: float = 4.0,
|
42 |
+
qkv_bias: bool = False,
|
43 |
+
proj_bias: bool = True,
|
44 |
+
ffn_bias: bool = True,
|
45 |
+
drop: float = 0.0,
|
46 |
+
attn_drop: float = 0.0,
|
47 |
+
init_values=None,
|
48 |
+
drop_path: float = 0.0,
|
49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
53 |
+
) -> None:
|
54 |
+
super().__init__()
|
55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
56 |
+
self.norm1 = norm_layer(dim)
|
57 |
+
self.attn = attn_class(
|
58 |
+
dim,
|
59 |
+
num_heads=num_heads,
|
60 |
+
qkv_bias=qkv_bias,
|
61 |
+
proj_bias=proj_bias,
|
62 |
+
attn_drop=attn_drop,
|
63 |
+
proj_drop=drop,
|
64 |
+
)
|
65 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
66 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
67 |
+
|
68 |
+
self.norm2 = norm_layer(dim)
|
69 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
70 |
+
self.mlp = ffn_layer(
|
71 |
+
in_features=dim,
|
72 |
+
hidden_features=mlp_hidden_dim,
|
73 |
+
act_layer=act_layer,
|
74 |
+
drop=drop,
|
75 |
+
bias=ffn_bias,
|
76 |
+
)
|
77 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
78 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
79 |
+
|
80 |
+
self.sample_drop_ratio = drop_path
|
81 |
+
|
82 |
+
def forward(self, x: Tensor) -> Tensor:
|
83 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
84 |
+
return self.ls1(self.attn(self.norm1(x)))
|
85 |
+
|
86 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
87 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
88 |
+
|
89 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
90 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
91 |
+
x = drop_add_residual_stochastic_depth(
|
92 |
+
x,
|
93 |
+
residual_func=attn_residual_func,
|
94 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
95 |
+
)
|
96 |
+
x = drop_add_residual_stochastic_depth(
|
97 |
+
x,
|
98 |
+
residual_func=ffn_residual_func,
|
99 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
100 |
+
)
|
101 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
102 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
103 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
104 |
+
else:
|
105 |
+
x = x + attn_residual_func(x)
|
106 |
+
x = x + ffn_residual_func(x)
|
107 |
+
return x
|
108 |
+
|
109 |
+
|
110 |
+
def drop_add_residual_stochastic_depth(
|
111 |
+
x: Tensor,
|
112 |
+
residual_func: Callable[[Tensor], Tensor],
|
113 |
+
sample_drop_ratio: float = 0.0,
|
114 |
+
) -> Tensor:
|
115 |
+
# 1) extract subset using permutation
|
116 |
+
b, n, d = x.shape
|
117 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
118 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
119 |
+
x_subset = x[brange]
|
120 |
+
|
121 |
+
# 2) apply residual_func to get residual
|
122 |
+
residual = residual_func(x_subset)
|
123 |
+
|
124 |
+
x_flat = x.flatten(1)
|
125 |
+
residual = residual.flatten(1)
|
126 |
+
|
127 |
+
residual_scale_factor = b / sample_subset_size
|
128 |
+
|
129 |
+
# 3) add the residual
|
130 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
131 |
+
return x_plus_residual.view_as(x)
|
132 |
+
|
133 |
+
|
134 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
135 |
+
b, n, d = x.shape
|
136 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
137 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
138 |
+
residual_scale_factor = b / sample_subset_size
|
139 |
+
return brange, residual_scale_factor
|
140 |
+
|
141 |
+
|
142 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
143 |
+
if scaling_vector is None:
|
144 |
+
x_flat = x.flatten(1)
|
145 |
+
residual = residual.flatten(1)
|
146 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
147 |
+
else:
|
148 |
+
x_plus_residual = scaled_index_add(
|
149 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
150 |
+
)
|
151 |
+
return x_plus_residual
|
152 |
+
|
153 |
+
|
154 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
155 |
+
|
156 |
+
|
157 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
158 |
+
"""
|
159 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
160 |
+
"""
|
161 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
162 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
163 |
+
if all_shapes not in attn_bias_cache.keys():
|
164 |
+
seqlens = []
|
165 |
+
for b, x in zip(batch_sizes, x_list):
|
166 |
+
for _ in range(b):
|
167 |
+
seqlens.append(x.shape[1])
|
168 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
169 |
+
attn_bias._batch_sizes = batch_sizes
|
170 |
+
attn_bias_cache[all_shapes] = attn_bias
|
171 |
+
|
172 |
+
if branges is not None:
|
173 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
174 |
+
else:
|
175 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
176 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
177 |
+
|
178 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
179 |
+
|
180 |
+
|
181 |
+
def drop_add_residual_stochastic_depth_list(
|
182 |
+
x_list: List[Tensor],
|
183 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
184 |
+
sample_drop_ratio: float = 0.0,
|
185 |
+
scaling_vector=None,
|
186 |
+
) -> Tensor:
|
187 |
+
# 1) generate random set of indices for dropping samples in the batch
|
188 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
189 |
+
branges = [s[0] for s in branges_scales]
|
190 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
191 |
+
|
192 |
+
# 2) get attention bias and index+concat the tensors
|
193 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
194 |
+
|
195 |
+
# 3) apply residual_func to get residual, and split the result
|
196 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
197 |
+
|
198 |
+
outputs = []
|
199 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
200 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
201 |
+
return outputs
|
202 |
+
|
203 |
+
|
204 |
+
class NestedTensorBlock(Block):
|
205 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
206 |
+
"""
|
207 |
+
x_list contains a list of tensors to nest together and run
|
208 |
+
"""
|
209 |
+
assert isinstance(self.attn, MemEffAttention)
|
210 |
+
|
211 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
212 |
+
|
213 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
214 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
215 |
+
|
216 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
217 |
+
return self.mlp(self.norm2(x))
|
218 |
+
|
219 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
220 |
+
x_list,
|
221 |
+
residual_func=attn_residual_func,
|
222 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
223 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
224 |
+
)
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=ffn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
230 |
+
)
|
231 |
+
return x_list
|
232 |
+
else:
|
233 |
+
|
234 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
235 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
236 |
+
|
237 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
238 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
239 |
+
|
240 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
241 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
242 |
+
x = x + ffn_residual_func(x)
|
243 |
+
return attn_bias.split(x)
|
244 |
+
|
245 |
+
def forward(self, x_or_x_list):
|
246 |
+
if isinstance(x_or_x_list, Tensor):
|
247 |
+
return super().forward(x_or_x_list)
|
248 |
+
elif isinstance(x_or_x_list, list):
|
249 |
+
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
|
250 |
+
return self.forward_nested(x_or_x_list)
|
251 |
+
else:
|
252 |
+
raise AssertionError
|