viranchi123 committed on
Commit ca46f55 · verified
1 Parent(s): c315ee6

Upload 75 files

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +1 -0
  2. .gradio/certificate.pem +31 -0
  3. README_Unified.md +128 -0
  4. app.py +299 -0
  5. app2.py +324 -0
  6. checkpoints/labels.txt +4 -0
  7. depth_anything_v2/__pycache__/dinov2.cpython-312.pyc +0 -0
  8. depth_anything_v2/__pycache__/dpt.cpython-312.pyc +0 -0
  9. depth_anything_v2/dinov2.py +415 -0
  10. depth_anything_v2/dinov2_layers/__init__.py +11 -0
  11. depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-312.pyc +0 -0
  12. depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-312.pyc +0 -0
  13. depth_anything_v2/dinov2_layers/__pycache__/block.cpython-312.pyc +0 -0
  14. depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-312.pyc +0 -0
  15. depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-312.pyc +0 -0
  16. depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-312.pyc +0 -0
  17. depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-312.pyc +0 -0
  18. depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-312.pyc +0 -0
  19. depth_anything_v2/dinov2_layers/attention.py +83 -0
  20. depth_anything_v2/dinov2_layers/block.py +252 -0
  21. depth_anything_v2/dinov2_layers/drop_path.py +35 -0
  22. depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
  23. depth_anything_v2/dinov2_layers/mlp.py +41 -0
  24. depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
  25. depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
  26. depth_anything_v2/dpt.py +221 -0
  27. depth_anything_v2/util/__pycache__/blocks.cpython-312.pyc +0 -0
  28. depth_anything_v2/util/__pycache__/transform.cpython-312.pyc +0 -0
  29. depth_anything_v2/util/blocks.py +148 -0
  30. depth_anything_v2/util/transform.py +158 -0
  31. environment.yml +0 -0
  32. environment_export.yml +182 -0
  33. environment_from_history.yml +30 -0
  34. environment_linux.yml +116 -0
  35. keras_model_3.h5 +3 -0
  36. labels.txt +4 -0
  37. main_app.py +540 -0
  38. metric_depth/README.md +114 -0
  39. metric_depth/assets/compare_zoedepth.png +3 -0
  40. metric_depth/dataset/hypersim.py +74 -0
  41. metric_depth/dataset/kitti.py +57 -0
  42. metric_depth/dataset/splits/hypersim/val.txt +0 -0
  43. metric_depth/dataset/splits/kitti/val.txt +0 -0
  44. metric_depth/dataset/splits/vkitti2/train.txt +0 -0
  45. metric_depth/dataset/transform.py +277 -0
  46. metric_depth/dataset/vkitti2.py +54 -0
  47. metric_depth/depth_anything_v2/dinov2.py +415 -0
  48. metric_depth/depth_anything_v2/dinov2_layers/__init__.py +11 -0
  49. metric_depth/depth_anything_v2/dinov2_layers/attention.py +83 -0
  50. metric_depth/depth_anything_v2/dinov2_layers/block.py +252 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ metric_depth/assets/compare_zoedepth.png filter=lfs diff=lfs merge=lfs -text
.gradio/certificate.pem ADDED
@@ -0,0 +1,31 @@
1
+ -----BEGIN CERTIFICATE-----
2
+ MIIFazCCA1OgAwIBAgIRAIIQz7DSQONZRGPgu2OCiwAwDQYJKoZIhvcNAQELBQAw
3
+ TzELMAkGA1UEBhMCVVMxKTAnBgNVBAoTIEludGVybmV0IFNlY3VyaXR5IFJlc2Vh
4
+ cmNoIEdyb3VwMRUwEwYDVQQDEwxJU1JHIFJvb3QgWDEwHhcNMTUwNjA0MTEwNDM4
5
+ WhcNMzUwNjA0MTEwNDM4WjBPMQswCQYDVQQGEwJVUzEpMCcGA1UEChMgSW50ZXJu
6
+ ZXQgU2VjdXJpdHkgUmVzZWFyY2ggR3JvdXAxFTATBgNVBAMTDElTUkcgUm9vdCBY
7
+ MTCCAiIwDQYJKoZIhvcNAQEBBQADggIPADCCAgoCggIBAK3oJHP0FDfzm54rVygc
8
+ h77ct984kIxuPOZXoHj3dcKi/vVqbvYATyjb3miGbESTtrFj/RQSa78f0uoxmyF+
9
+ 0TM8ukj13Xnfs7j/EvEhmkvBioZxaUpmZmyPfjxwv60pIgbz5MDmgK7iS4+3mX6U
10
+ A5/TR5d8mUgjU+g4rk8Kb4Mu0UlXjIB0ttov0DiNewNwIRt18jA8+o+u3dpjq+sW
11
+ T8KOEUt+zwvo/7V3LvSye0rgTBIlDHCNAymg4VMk7BPZ7hm/ELNKjD+Jo2FR3qyH
12
+ B5T0Y3HsLuJvW5iB4YlcNHlsdu87kGJ55tukmi8mxdAQ4Q7e2RCOFvu396j3x+UC
13
+ B5iPNgiV5+I3lg02dZ77DnKxHZu8A/lJBdiB3QW0KtZB6awBdpUKD9jf1b0SHzUv
14
+ KBds0pjBqAlkd25HN7rOrFleaJ1/ctaJxQZBKT5ZPt0m9STJEadao0xAH0ahmbWn
15
+ OlFuhjuefXKnEgV4We0+UXgVCwOPjdAvBbI+e0ocS3MFEvzG6uBQE3xDk3SzynTn
16
+ jh8BCNAw1FtxNrQHusEwMFxIt4I7mKZ9YIqioymCzLq9gwQbooMDQaHWBfEbwrbw
17
+ qHyGO0aoSCqI3Haadr8faqU9GY/rOPNk3sgrDQoo//fb4hVC1CLQJ13hef4Y53CI
18
+ rU7m2Ys6xt0nUW7/vGT1M0NPAgMBAAGjQjBAMA4GA1UdDwEB/wQEAwIBBjAPBgNV
19
+ HRMBAf8EBTADAQH/MB0GA1UdDgQWBBR5tFnme7bl5AFzgAiIyBpY9umbbjANBgkq
20
+ hkiG9w0BAQsFAAOCAgEAVR9YqbyyqFDQDLHYGmkgJykIrGF1XIpu+ILlaS/V9lZL
21
+ ubhzEFnTIZd+50xx+7LSYK05qAvqFyFWhfFQDlnrzuBZ6brJFe+GnY+EgPbk6ZGQ
22
+ 3BebYhtF8GaV0nxvwuo77x/Py9auJ/GpsMiu/X1+mvoiBOv/2X/qkSsisRcOj/KK
23
+ NFtY2PwByVS5uCbMiogziUwthDyC3+6WVwW6LLv3xLfHTjuCvjHIInNzktHCgKQ5
24
+ ORAzI4JMPJ+GslWYHb4phowim57iaztXOoJwTdwJx4nLCgdNbOhdjsnvzqvHu7Ur
25
+ TkXWStAmzOVyyghqpZXjFaH3pO3JLF+l+/+sKAIuvtd7u+Nxe5AW0wdeRlN8NwdC
26
+ jNPElpzVmbUq4JUagEiuTDkHzsxHpFKVK7q4+63SM1N95R1NbdWhscdCb+ZAJzVc
27
+ oyi3B43njTOQ5yOf+1CceWxG1bQVs5ZufpsMljq4Ui0/1lvh+wjChP4kqKOJ2qxq
28
+ 4RgqsahDYVvTH9w7jXbyLeiNdd8XM2w9U/t7y0Ff/9yi0GE44Za4rF2LN9d11TPA
29
+ mRGunUHBcnWEvgJBQl9nJEiU0Zsnvgc/ubhPgXRR4Xq37Z0j4r7g1SgEEzwxA57d
30
+ emyPxgcYxn/eR44/KJ4EBs+lVDR3veyJm+kXQ99b21/+jh5Xos1AnX5iItreGCc=
31
+ -----END CERTIFICATE-----
README_Unified.md ADDED
@@ -0,0 +1,128 @@
+ # Medical AI Suite - Unified Interface
+
+ A comprehensive web application that combines wound classification and depth estimation capabilities in a single, modern interface.
+
+ ## 🚀 Quick Start
+
+ ### Option 1: Use the Launcher (Recommended)
+ ```bash
+ python launcher.py
+ ```
+ This will show you a menu to choose which application to run.
+
+ ### Option 2: Run the Unified Interface Directly
+ ```bash
+ python main_app.py
+ ```
+
+ ### Option 3: Run Individual Applications
+ ```bash
+ # Wound Classification only
+ python app2.py
+
+ # Depth Estimation only
+ python app.py
+ ```
+
+ ## 🏥 Features
+
+ ### Tab 1: Wound Classification
+ - **AI-powered wound type classification**
+ - **Grad-CAM visualization** - See which areas the model focuses on
+ - **Confidence scores** with color-coded bars
+ - **Real-time analysis** - Results update as you upload images
+
+ ### Tab 2: Depth Estimation & 3D Visualization
+ - **Depth map generation** using DepthAnythingV2 model
+ - **Interactive 3D point cloud visualization**
+ - **Adjustable parameters** (focal length, point density)
+ - **Multiple output formats** (grayscale, raw, PLY point cloud)
+ - **Image slider comparison** between original and depth map
+
+ ## 🎨 Interface Features
+
+ - **Modern dark theme** with gradient backgrounds
+ - **Tabbed navigation** between applications
+ - **Responsive design** that works on different screen sizes
+ - **Professional medical interface** styling
+ - **Real-time feedback** and progress indicators
+
+ ## 📁 File Structure
+
+ ```
+ ├── main_app.py # Unified interface (NEW)
+ ├── launcher.py # Application launcher (NEW)
+ ├── app.py # Original depth estimation app
+ ├── app2.py # Original wound classification app
+ ├── checkpoints/
+ │ ├── keras_model.h5 # Wound classification model
+ │ └── depth_anything_v2_vitl.pth # Depth estimation model
+ ├── labels.txt # Wound classification labels
+ └── depth_anything_v2/ # Depth model implementation
+ ```
+
+ ## 🔧 Requirements
+
+ The unified interface requires all the same dependencies as the individual applications:
+
+ - `gradio`
+ - `tensorflow`
+ - `torch`
+ - `opencv-python`
+ - `pillow`
+ - `numpy`
+ - `matplotlib`
+ - `plotly`
+ - `open3d`
+ - `gradio-imageslider`
+
+ ## 🌐 Access
+
+ Once launched, the interface will be available at:
+ - **Local**: http://localhost:7860
+ - **Public**: A public link will be provided when the server starts
+
+ ## 💡 Usage Tips
+
+ ### Wound Classification
+ 1. Upload a clear image of the wound
+ 2. The model will automatically classify the wound type
+ 3. View the Grad-CAM heatmap to see which areas influenced the decision
+ 4. Check confidence scores for all possible classifications
+
+ ### Depth Estimation
+ 1. Upload an image for depth analysis
+ 2. Adjust the number of 3D points (higher = more detailed but slower)
+ 3. Set focal length parameters if you know your camera specs
+ 4. Click "Compute Depth" to generate results
+ 5. Download depth maps and point clouds as needed
+ 6. Explore the interactive 3D visualization
+
+ ## 🛠️ Troubleshooting
+
+ ### Model Loading Issues
+ If models fail to load, the interface will show appropriate error messages and continue to function with limited capabilities.
+
+ ### Performance
+ - For large images, consider reducing the number of 3D points
+ - Depth estimation works best with good lighting and clear subjects
+ - Wound classification works best with well-lit, focused images
+
+ ### Browser Compatibility
+ The interface works best with modern browsers (Chrome, Firefox, Safari, Edge).
+
+ ## 🔄 Navigation
+
+ You can easily switch between the two main functionalities using the tabs at the top of the interface. Each tab maintains its own state, so you can work on both applications simultaneously.
+
+ ## 📞 Support
+
+ If you encounter any issues:
+ 1. Check that all required model files are present
+ 2. Ensure all dependencies are installed
+ 3. Try running individual applications first to isolate issues
+ 4. Check the console output for error messages
+
+ ---
+
+ **Enjoy using the Medical AI Suite! 🏥✨**
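
The requirements list in README_Unified.md above uses pip package names that differ from their import names. A minimal pre-flight sketch, assuming the usual aliases (`cv2` for opencv-python, `PIL` for pillow, `gradio_imageslider` for gradio-imageslider); it also reports the device that app.py's `DEVICE` selection below would pick:

```python
# Pre-flight check: confirm the README's dependencies import cleanly and
# report which torch device app.py would run on. Stdlib + the packages only.
import importlib

MODULES = [
    "gradio", "tensorflow", "torch", "cv2", "PIL",
    "numpy", "matplotlib", "plotly", "open3d", "gradio_imageslider",
]

def check_imports(names):
    missing = []
    for name in names:
        try:
            importlib.import_module(name)
        except ImportError:
            missing.append(name)
    return missing

if __name__ == "__main__":
    missing = check_imports(MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        import torch
        # Same selection order as app.py: cuda, then mps, then cpu.
        device = (
            "cuda" if torch.cuda.is_available()
            else "mps" if torch.backends.mps.is_available()
            else "cpu"
        )
        print(f"All dependencies found; the apps would run on: {device}")
```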
app.py ADDED
@@ -0,0 +1,299 @@
1
+ import glob
2
+ import gradio as gr
3
+ import matplotlib
4
+ import numpy as np
5
+ from PIL import Image
6
+ import torch
7
+ import tempfile
8
+ from gradio_imageslider import ImageSlider
9
+ import plotly.graph_objects as go
10
+ import plotly.express as px
11
+ import open3d as o3d
12
+ from depth_anything_v2.dpt import DepthAnythingV2
13
+
14
+ css = """
15
+ #img-display-container {
16
+ max-height: 100vh;
17
+ }
18
+ #img-display-input {
19
+ max-height: 80vh;
20
+ }
21
+ #img-display-output {
22
+ max-height: 80vh;
23
+ }
24
+ #download {
25
+ height: 62px;
26
+ }
27
+ h1 {
28
+ text-align: center;
29
+ font-size: 3rem;
30
+ font-weight: bold;
31
+ margin: 2rem 0;
32
+ color: #2c3e50;
33
+ }
34
+ """
35
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
36
+ model_configs = {
37
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
38
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
39
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
40
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
41
+ }
42
+ encoder = 'vitl'
43
+ model = DepthAnythingV2(**model_configs[encoder])
44
+ state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
45
+ model.load_state_dict(state_dict)
46
+ model = model.to(DEVICE).eval()
47
+
48
+ title = "Depth Estimation, 3D Visualization"
49
+ description = """Official demo for **Depth Estimation, 3D Visualization**."""
50
+
51
+ def predict_depth(image):
52
+ return model.infer_image(image)
53
+
54
+ def calculate_max_points(image):
55
+ """Calculate maximum points based on image dimensions (3x pixel count)"""
56
+ if image is None:
57
+ return 10000 # Default value
58
+ h, w = image.shape[:2]
59
+ max_points = h * w * 3
60
+ # Ensure minimum and reasonable maximum values
61
+ return max(1000, min(max_points, 1000000))
62
+
63
+ def update_slider_on_image_upload(image):
64
+ """Update the points slider when an image is uploaded"""
65
+ max_points = calculate_max_points(image)
66
+ default_value = min(10000, max_points // 10) # 10% of max points as default
67
+ return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
68
+ label=f"Number of 3D points (max: {max_points:,})")
69
+
70
+ def create_3d_depth_visualization(image, depth_map, max_points=10000):
71
+ """Create an interactive 3D visualization of the depth map"""
72
+ h, w = depth_map.shape
73
+
74
+ # Downsample to avoid too many points for performance
75
+ step = max(1, int(np.sqrt(h * w / max_points)))
76
+
77
+ # Create coordinate grids
78
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
79
+ depth_values = depth_map[::step, ::step]
80
+
81
+ # Flatten arrays
82
+ x_flat = x_coords.flatten()
83
+ y_flat = y_coords.flatten()
84
+ z_flat = depth_values.flatten()
85
+
86
+ # Get corresponding image colors
87
+ image_colors = image[::step, ::step, :]
88
+ colors_flat = image_colors.reshape(-1, 3)
89
+
90
+ # Create 3D scatter plot
91
+ fig = go.Figure(data=[go.Scatter3d(
92
+ x=x_flat,
93
+ y=y_flat,
94
+ z=z_flat,
95
+ mode='markers',
96
+ marker=dict(
97
+ size=2,
98
+ color=colors_flat,
99
+ opacity=0.8
100
+ ),
101
+ hovertemplate='<b>Position:</b> (%{x:.0f}, %{y:.0f})<br>' +
102
+ '<b>Depth:</b> %{z:.2f}<br>' +
103
+ '<extra></extra>'
104
+ )])
105
+
106
+ fig.update_layout(
107
+ title="3D Depth Visualization (Hover to see depth values)",
108
+ scene=dict(
109
+ xaxis_title="X (pixels)",
110
+ yaxis_title="Y (pixels)",
111
+ zaxis_title="Depth",
112
+ camera=dict(
113
+ eye=dict(x=1.5, y=1.5, z=1.5)
114
+ )
115
+ ),
116
+ width=600,
117
+ height=500
118
+ )
119
+
120
+ return fig
121
+
122
+ def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
123
+ """Create a point cloud from depth map using camera intrinsics"""
124
+ h, w = depth_map.shape
125
+
126
+ # Downsample to avoid too many points for performance
127
+ step = max(1, int(np.sqrt(h * w / max_points)))
128
+
129
+ # Create mesh grid for camera coordinates
130
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
131
+
132
+ # Convert to camera coordinates (normalized by focal length)
133
+ x_cam = (x_coords - w / 2) / focal_length_x
134
+ y_cam = (y_coords - h / 2) / focal_length_y
135
+
136
+ # Get depth values
137
+ depth_values = depth_map[::step, ::step]
138
+
139
+ # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
140
+ x_3d = x_cam * depth_values
141
+ y_3d = y_cam * depth_values
142
+ z_3d = depth_values
143
+
144
+ # Flatten arrays
145
+ points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
146
+
147
+ # Get corresponding image colors
148
+ image_colors = image[::step, ::step, :]
149
+ colors = image_colors.reshape(-1, 3) / 255.0
150
+
151
+ # Create Open3D point cloud
152
+ pcd = o3d.geometry.PointCloud()
153
+ pcd.points = o3d.utility.Vector3dVector(points)
154
+ pcd.colors = o3d.utility.Vector3dVector(colors)
155
+
156
+ return pcd
157
+
158
+ def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
159
+ """Create an enhanced 3D visualization using proper camera projection"""
160
+ h, w = depth_map.shape
161
+
162
+ # Downsample to avoid too many points for performance
163
+ step = max(1, int(np.sqrt(h * w / max_points)))
164
+
165
+ # Create mesh grid for camera coordinates
166
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
167
+
168
+ # Convert to camera coordinates (normalized by focal length)
169
+ focal_length = 470.4 # Default focal length
170
+ x_cam = (x_coords - w / 2) / focal_length
171
+ y_cam = (y_coords - h / 2) / focal_length
172
+
173
+ # Get depth values
174
+ depth_values = depth_map[::step, ::step]
175
+
176
+ # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
177
+ x_3d = x_cam * depth_values
178
+ y_3d = y_cam * depth_values
179
+ z_3d = depth_values
180
+
181
+ # Flatten arrays
182
+ x_flat = x_3d.flatten()
183
+ y_flat = y_3d.flatten()
184
+ z_flat = z_3d.flatten()
185
+
186
+ # Get corresponding image colors
187
+ image_colors = image[::step, ::step, :]
188
+ colors_flat = image_colors.reshape(-1, 3)
189
+
190
+ # Create 3D scatter plot with proper camera projection
191
+ fig = go.Figure(data=[go.Scatter3d(
192
+ x=x_flat,
193
+ y=y_flat,
194
+ z=z_flat,
195
+ mode='markers',
196
+ marker=dict(
197
+ size=1.5,
198
+ color=colors_flat,
199
+ opacity=0.9
200
+ ),
201
+ hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
202
+ '<b>Depth:</b> %{z:.2f}<br>' +
203
+ '<extra></extra>'
204
+ )])
205
+
206
+ fig.update_layout(
207
+ title="3D Point Cloud Visualization (Camera Projection)",
208
+ scene=dict(
209
+ xaxis_title="X (meters)",
210
+ yaxis_title="Y (meters)",
211
+ zaxis_title="Z (meters)",
212
+ camera=dict(
213
+ eye=dict(x=2.0, y=2.0, z=2.0),
214
+ center=dict(x=0, y=0, z=0),
215
+ up=dict(x=0, y=0, z=1)
216
+ ),
217
+ aspectmode='data'
218
+ ),
219
+ width=700,
220
+ height=600
221
+ )
222
+
223
+ return fig
224
+
225
+ with gr.Blocks(css=css) as demo:
226
+ gr.HTML(f"<h1>{title}</h1>")
227
+ gr.Markdown(description)
228
+ gr.Markdown("### Depth Prediction demo")
229
+
230
+ with gr.Row():
231
+ input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
232
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
233
+
234
+ with gr.Row():
235
+ submit = gr.Button(value="Compute Depth", variant="primary")
236
+ points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
237
+ label="Number of 3D points (upload image to update max)")
238
+
239
+ with gr.Row():
240
+ focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
241
+ label="Focal Length X (pixels)")
242
+ focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
243
+ label="Focal Length Y (pixels)")
244
+
245
+ with gr.Row():
246
+ gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
247
+ raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
248
+ point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
249
+
250
+ # 3D Visualization
251
+ gr.Markdown("### 3D Point Cloud Visualization")
252
+ gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
253
+ depth_3d_plot = gr.Plot(label="3D Point Cloud")
254
+
255
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
256
+
257
+ def on_submit(image, num_points, focal_x, focal_y):
258
+ original_image = image.copy()
259
+
260
+ h, w = image.shape[:2]
261
+
262
+ depth = predict_depth(image[:, :, ::-1])
263
+
264
+ raw_depth = Image.fromarray(depth.astype('uint16'))
265
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
266
+ raw_depth.save(tmp_raw_depth.name)
267
+
268
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
269
+ depth = depth.astype(np.uint8)
270
+ colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
271
+
272
+ gray_depth = Image.fromarray(depth)
273
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
274
+ gray_depth.save(tmp_gray_depth.name)
275
+
276
+ # Create point cloud
277
+ pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
278
+ tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
279
+ o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
280
+
281
+ # Create enhanced 3D visualization
282
+ depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
283
+
284
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
285
+
286
+ # Update slider when image is uploaded
287
+ input_image.change(
288
+ fn=update_slider_on_image_upload,
289
+ inputs=[input_image],
290
+ outputs=[points_slider]
291
+ )
292
+
293
+ submit.click(on_submit, inputs=[input_image, points_slider, focal_length_x, focal_length_y],
294
+ outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot])
295
+
296
+
297
+
298
+ if __name__ == '__main__':
299
+ demo.queue().launch()
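
The point-cloud step in `create_point_cloud` above is a standard pinhole back-projection: a pixel (u, v) with depth z maps to ((u - cx) / fx * z, (v - cy) / fy * z, z), with the principal point taken at the image centre. A self-contained sketch of just that math on a synthetic depth map (no model or checkpoint needed; 470.4 px mirrors the app's default focal length):

```python
import numpy as np

def backproject(depth, fx=470.4, fy=470.4):
    """Back-project a depth map of shape (H, W) to an (N, 3) array of camera-space points."""
    h, w = depth.shape
    v, u = np.mgrid[0:h, 0:w]            # pixel row (v) and column (u) grids
    x = (u - w / 2) / fx * depth         # X = (u - cx) / fx * Z
    y = (v - h / 2) / fy * depth         # Y = (v - cy) / fy * Z
    return np.stack([x.ravel(), y.ravel(), depth.ravel()], axis=1)

if __name__ == "__main__":
    # Synthetic depth: a plane tilted along the x axis, 1 to 2 units deep.
    depth = np.tile(np.linspace(1.0, 2.0, 64), (48, 1))
    pts = backproject(depth)
    print(pts.shape)                      # (3072, 3)
    print(pts[:, 2].min(), pts[:, 2].max())  # 1.0 2.0
```

The app attaches per-pixel RGB to these points before writing the PLY with Open3D; the geometry itself is only the three lines inside `backproject`.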
app2.py ADDED
@@ -0,0 +1,324 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing import image as keras_image
6
+ from tensorflow.keras import backend as K
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import io
10
+ import cv2
11
+
12
+ # --- Load model and labels ---
13
+ model = load_model("checkpoints/keras_model.h5")
14
+ with open("labels.txt", "r") as f:
15
+ class_labels = [line.strip() for line in f]
16
+
17
+ # --- Preprocess input ---
18
+ def preprocess_input(img):
19
+ img = img.resize((224, 224))
20
+ arr = keras_image.img_to_array(img)
21
+ arr = arr / 255.0
22
+ return np.expand_dims(arr, axis=0)
23
+
24
+ # --- Enhanced Grad-CAM implementation for Keras ---
25
+ def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
26
+ try:
27
+ # Try to find the specified layer
28
+ target_layer = model.get_layer(last_conv_layer_name)
29
+ except:
30
+ # Fallback: find any convolutional layer
31
+ for layer in model.layers:
32
+ if 'conv' in layer.name.lower():
33
+ target_layer = layer
34
+ break
35
+ else:
36
+ return None
37
+
38
+ grad_model = tf.keras.models.Model(
39
+ [model.inputs], [target_layer.output, model.output]
40
+ )
41
+
42
+ with tf.GradientTape() as tape:
43
+ conv_outputs, predictions = grad_model(img_array)
44
+ loss = predictions[:, class_index]
45
+
46
+ grads = tape.gradient(loss, conv_outputs)[0]
47
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
48
+ conv_outputs = conv_outputs[0]
49
+
50
+ heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
51
+ heatmap = np.maximum(heatmap, 0)
52
+ heatmap = heatmap / np.max(heatmap + K.epsilon())
53
+ return heatmap.numpy()
54
+
55
+ # --- Enhanced Overlay heatmap on image ---
56
+ def overlay_gradcam(original_img, heatmap):
57
+ if heatmap is None:
58
+ return original_img
59
+
60
+ # Resize heatmap
61
+ heatmap = cv2.resize(heatmap, original_img.size)
62
+
63
+ # Normalize safely
64
+ heatmap = np.maximum(heatmap, 0)
65
+ if np.max(heatmap) != 0:
66
+ heatmap /= np.max(heatmap)
67
+ heatmap = np.uint8(255 * heatmap)
68
+
69
+ # Apply JET colormap for better medical visualization
70
+ heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
71
+
72
+ # Convert PIL to array
73
+ original_array = np.array(original_img.convert("RGB"))
74
+
75
+ # Enhanced blend with better contrast
76
+ superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
77
+
78
+ return Image.fromarray(superimposed_img)
79
+
80
+ # --- Enhanced Prediction Function ---
81
+ def classify_and_explain(img):
82
+ if img is None:
83
+ return None, {}, "No image provided"
84
+
85
+ img_array = preprocess_input(img)
86
+ predictions = model.predict(img_array, verbose=0)[0]
87
+ pred_idx = int(np.argmax(predictions))
88
+ pred_class = class_labels[pred_idx]
89
+ confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
90
+
91
+ # Enhanced Grad-CAM
92
+ try:
93
+ heatmap = get_gradcam_heatmap(img_array, model, pred_idx)
94
+ gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
95
+ except Exception as e:
96
+ print(f"Grad-CAM error: {e}")
97
+ gradcam_img = img.resize((224, 224)) # fallback image
98
+
99
+ return gradcam_img, confidence_dict
100
+
101
+ # --- Custom CSS for Dark Mode Medical Interface ---
102
+ css = """
103
+ .gradio-container {
104
+ font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
105
+ background: #1a1a1a;
106
+ min-height: 100vh;
107
+ padding: 20px;
108
+ color: #ffffff;
109
+ }
110
+
111
+ .main-header {
112
+ text-align: center;
113
+ color: white;
114
+ margin-bottom: 2rem;
115
+ padding: 2rem 0;
116
+ }
117
+
118
+ .main-header h1 {
119
+ font-size: 2.5rem;
120
+ margin-bottom: 0.5rem;
121
+ text-shadow: 2px 2px 4px rgba(0,0,0,0.5);
122
+ color: #ffffff;
123
+ }
124
+
125
+ .confidence-bar {
126
+ background: linear-gradient(90deg, #3498db 0%, #2ecc71 100%);
127
+ height: 25px;
128
+ border-radius: 12px;
129
+ margin: 8px 0;
130
+ transition: all 0.3s ease;
131
+ box-shadow: 0 2px 4px rgba(0,0,0,0.3);
132
+ }
133
+
134
+ .confidence-container {
135
+ margin: 15px 0;
136
+ padding: 20px;
137
+ border-radius: 12px;
138
+ background: rgba(255,255,255,0.1);
139
+ backdrop-filter: blur(10px);
140
+ box-shadow: 0 8px 32px rgba(0,0,0,0.3);
141
+ border: 1px solid rgba(255,255,255,0.1);
142
+ }
143
+
144
+ .input-section, .output-section {
145
+ background: rgba(255,255,255,0.05);
146
+ padding: 25px;
147
+ border-radius: 15px;
148
+ margin: 15px;
149
+ backdrop-filter: blur(10px);
150
+ box-shadow: 0 8px 32px rgba(0,0,0,0.3);
151
+ border: 1px solid rgba(255,255,255,0.1);
152
+ }
153
+
154
+ .section-title {
155
+ color: #ffffff;
156
+ font-size: 1.3rem;
157
+ font-weight: 600;
158
+ margin-bottom: 15px;
159
+ border-bottom: 2px solid #3498db;
160
+ padding-bottom: 8px;
161
+ }
162
+
163
+ .gradio-button {
164
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
165
+ border: none;
166
+ color: white;
167
+ padding: 12px 24px;
168
+ border-radius: 25px;
169
+ font-weight: 600;
170
+ transition: all 0.3s ease;
171
+ box-shadow: 0 4px 15px rgba(0,0,0,0.3);
172
+ }
173
+
174
+ .gradio-button:hover {
175
+ transform: translateY(-2px);
176
+ box-shadow: 0 6px 20px rgba(0,0,0,0.4);
177
+ }
178
+
179
+ .gradio-image {
180
+ border-radius: 12px;
181
+ box-shadow: 0 4px 15px rgba(0,0,0,0.3);
182
+ border: 1px solid rgba(255,255,255,0.1);
183
+ }
184
+
185
+ .gradio-textbox, .gradio-number {
186
+ border-radius: 8px;
187
+ border: 2px solid #333333;
188
+ padding: 12px;
189
+ font-size: 1rem;
190
+ background: rgba(255,255,255,0.05);
191
+ color: #ffffff;
192
+ }
193
+
194
+ .gradio-textbox:focus, .gradio-number:focus {
195
+ border-color: #3498db;
196
+ box-shadow: 0 0 0 0.2rem rgba(52,152,219,0.25);
197
+ }
198
+
199
+ .gradio-label {
200
+ color: #ffffff !important;
201
+ }
202
+
203
+ .heatmap-container {
204
+ background: rgba(255,255,255,0.05);
205
+ padding: 15px;
206
+ border-radius: 12px;
207
+ border: 1px solid rgba(255,255,255,0.1);
208
+ margin: 10px 0;
209
+ }
210
+
211
+ .prediction-container {
212
+ background: rgba(52,152,219,0.1);
213
+ padding: 20px;
214
+ border-radius: 12px;
215
+ border-left: 5px solid #3498db;
216
+ margin: 15px 0;
217
+ }
218
+ """
219
+
220
+ # --- Function to create confidence bars HTML ---
221
+ def create_confidence_bars(confidence_dict):
222
+ html_content = "<div class='confidence-container'>"
223
+ for class_name, confidence in confidence_dict.items():
224
+ percentage = confidence * 100
225
+ # Color coding based on confidence
226
+ if percentage > 70:
227
+ color = "#28a745" # Green for high confidence
228
+ elif percentage > 40:
229
+ color = "#ffc107" # Yellow for medium confidence
230
+ else:
231
+ color = "#dc3545" # Red for low confidence
232
+
233
+ html_content += f"""
234
+ <div style='margin: 12px 0;'>
235
+ <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
236
+ <span style='font-weight: bold; color: {color};'>{class_name}</span>
237
+ <span style='font-weight: bold; color: {color};'>{percentage:.1f}%</span>
238
+ </div>
239
+ <div class='confidence-bar' style='width: {percentage}%; background: {color};'></div>
240
+ </div>
241
+ """
242
+ html_content += "</div>"
243
+ return html_content
244
+
245
+ # --- Enhanced Prediction Function with Dark Mode Interface ---
246
+ def enhanced_classify_and_explain(img):
247
+ if img is None:
248
+ return None, "No image provided", 0, ""
249
+
250
+ gradcam_img, confidence_dict = classify_and_explain(img)
251
+
252
+ # Get predicted class and confidence
253
+ pred_class = max(confidence_dict, key=confidence_dict.get)
254
+ confidence = confidence_dict[pred_class]
255
+
256
+ # Create confidence bars HTML
257
+ confidence_bars_html = create_confidence_bars(confidence_dict)
258
+
259
+ return gradcam_img, pred_class, confidence, confidence_bars_html
260
+
261
+ # --- Enhanced Gradio Interface ---
262
+ with gr.Blocks(css=css, title="Wound Classification") as demo:
263
+ gr.HTML("""
264
+ <div class="main-header">
265
+ <h1>Wound Classification</h1>
266
+ </div>
267
+ """)
268
+
269
+ with gr.Row():
270
+ with gr.Column(scale=1):
271
+ gr.HTML("<div class='section-title'>Input Image</div>")
272
+ input_image = gr.Image(
273
+ label="Upload wound image",
274
+ type="pil",
275
+ height=350,
276
+ container=True
277
+ )
278
+
279
+ with gr.Column(scale=1):
280
+ gr.HTML("<div class='section-title'>Analysis Results</div>")
281
+
282
+ # Prediction results
283
+ prediction_output = gr.Textbox(
284
+ label="Predicted Wound Type",
285
+ interactive=False,
286
+ container=True
287
+ )
288
+
289
+ confidence_output = gr.Number(
290
+ label="Confidence Score",
291
+ interactive=False,
292
+ container=True
293
+ )
294
+
295
+ # Confidence bars for all classes
296
+ confidence_bars = gr.HTML(
297
+ label="Confidence Scores by Class",
298
+ container=True
299
+ )
300
+
301
+ with gr.Row():
302
+ with gr.Column():
303
+ gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
304
+ cam_output = gr.Image(
305
+ label="Grad-CAM Heatmap - Shows which areas the model focused on",
306
+ height=350,
307
+ container=True
308
+ )
309
+
310
+ # Event handlers
311
+ input_image.change(
312
+ fn=enhanced_classify_and_explain,
313
+ inputs=[input_image],
314
+ outputs=[cam_output, prediction_output, confidence_output, confidence_bars]
315
+ )
316
+
317
+ # --- Launch the enhanced interface ---
318
+ if __name__ == "__main__":
319
+ demo.launch(
320
+ server_name="0.0.0.0",
321
+ server_port=7860,
322
+ share=True,
323
+ show_error=True
324
+ )
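
app2.py ties preprocessing, prediction and Grad-CAM to the Gradio callbacks. For reference, a minimal command-line sketch of just the classification path, assuming `checkpoints/keras_model.h5` and `labels.txt` are present as the code above expects (paths taken from the code, not re-verified here):

```python
import sys
import numpy as np
from PIL import Image
from tensorflow.keras.models import load_model

def classify(image_path, model_path="checkpoints/keras_model.h5", labels_path="labels.txt"):
    model = load_model(model_path)
    with open(labels_path) as f:
        # Lines look like "0 Burns"; keep the whole stripped line, matching app2.py.
        labels = [line.strip() for line in f]

    # Same preprocessing as app2.py: 224x224 RGB, scaled to [0, 1], batch dimension added.
    img = Image.open(image_path).convert("RGB").resize((224, 224))
    arr = np.expand_dims(np.asarray(img, dtype=np.float32) / 255.0, axis=0)

    probs = model.predict(arr, verbose=0)[0]
    best = int(np.argmax(probs))
    return labels[best], float(probs[best])

if __name__ == "__main__":
    label, confidence = classify(sys.argv[1])
    print(f"{label}: {confidence:.1%}")
```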
checkpoints/labels.txt ADDED
@@ -0,0 +1,4 @@
+ 0 Burns
+ 1 Surgical Wound
+ 2 Traumatic Wound
+ 3 Diabetic Foot Ulcer
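
Each line of labels.txt is an `index name` pair; app2.py keeps the whole stripped line as the display label. If only the class name is wanted, the two fields split cleanly. A small sketch, using the `checkpoints/labels.txt` path shown above:

```python
# Parse "index name" pairs into an {index: name} mapping.
with open("checkpoints/labels.txt") as f:
    entries = [line.strip().split(" ", 1) for line in f if line.strip()]

labels = {int(idx): name for idx, name in entries}
print(labels)
# {0: 'Burns', 1: 'Surgical Wound', 2: 'Traumatic Wound', 3: 'Diabetic Foot Ulcer'}
```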
depth_anything_v2/__pycache__/dinov2.cpython-312.pyc ADDED
Binary file (18.7 kB).
 
depth_anything_v2/__pycache__/dpt.cpython-312.pyc ADDED
Binary file (10.6 kB).
 
depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
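
The DPT decoder in `depth_anything_v2/dpt.py` (listed above but outside this 50-file view) builds on intermediate ViT features from this backbone. A quick shape-level sketch of that interface, assuming the `depth_anything_v2` package from this commit is importable; with `block_chunks=0` and a 14-pixel patch size, a 518×518 input yields a 37×37 token grid:

```python
import torch
from depth_anything_v2.dinov2 import DINOv2

# ViT-S/14 via the factory above (img_size=518, no register tokens).
# Weights are randomly initialised here; enough to check shapes.
backbone = DINOv2("vits").eval()

x = torch.randn(1, 3, 518, 518)   # height/width must be multiples of 14
with torch.no_grad():
    feats = backbone.get_intermediate_layers(x, n=4, reshape=True)

for f in feats:
    print(f.shape)   # torch.Size([1, 384, 37, 37]) for each of the last 4 blocks
```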
depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+ from .block import NestedTensorBlock
+ from .attention import MemEffAttention
depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (441 Bytes).
 
depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-312.pyc ADDED
Binary file (3.95 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/block.cpython-312.pyc ADDED
Binary file (13.1 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-312.pyc ADDED
Binary file (1.65 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-312.pyc ADDED
Binary file (1.42 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-312.pyc ADDED
Binary file (1.85 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-312.pyc ADDED
Binary file (4.06 kB).
 
depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-312.pyc ADDED
Binary file (2.84 kB).
 
depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
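
`MemEffAttention` only swaps in the xFormers kernel; without xFormers it asserts that no `attn_bias` is passed and falls back to the plain `Attention.forward` above. A small shape check, assuming the package path used elsewhere in this commit:

```python
import torch
from depth_anything_v2.dinov2_layers.attention import Attention, MemEffAttention

tokens = torch.randn(2, 197, 384)   # (batch, sequence length, embed_dim)

attn = Attention(dim=384, num_heads=6, qkv_bias=True).eval()
mem_attn = MemEffAttention(dim=384, num_heads=6, qkv_bias=True).eval()

with torch.no_grad():
    out = attn(tokens)
    out_mem = mem_attn(tokens)      # xFormers kernel if installed, else the code path above

print(out.shape, out_mem.shape)     # torch.Size([2, 197, 384]) in both cases
```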
depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError
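
The helpers above implement stochastic depth at the batch level: `drop_add_residual_stochastic_depth` runs the residual branch on a random subset of the batch and rescales the result by `b / sample_subset_size` so the expected update matches the full-batch residual. A minimal standalone sketch of that idea (illustrative only, with a toy residual function standing in for the attention/FFN branches):

```python
import torch

def toy_drop_add_residual(x, residual_func, sample_drop_ratio=0.5):
    # Pick a random subset of the batch, run the branch only on that subset,
    # and rescale so the expected update equals the full-batch residual.
    b = x.shape[0]
    keep = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(b, device=x.device)[:keep]
    residual = residual_func(x[brange])
    scale = b / keep
    out = torch.index_add(x.flatten(1), 0, brange,
                          residual.flatten(1), alpha=scale)
    return out.view_as(x)

x = torch.randn(8, 4, 16)
out = toy_drop_add_residual(x, lambda t: 0.1 * t)
print(out.shape)  # torch.Size([8, 4, 16])
```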
depth_anything_v2/dinov2_layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ if drop_prob == 0.0 or not training:
17
+ return x
18
+ keep_prob = 1 - drop_prob
19
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
20
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
21
+ if keep_prob > 0.0:
22
+ random_tensor.div_(keep_prob)
23
+ output = x * random_tensor
24
+ return output
25
+
26
+
27
+ class DropPath(nn.Module):
28
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
29
+
30
+ def __init__(self, drop_prob=None):
31
+ super(DropPath, self).__init__()
32
+ self.drop_prob = drop_prob
33
+
34
+ def forward(self, x):
35
+ return drop_path(x, self.drop_prob, self.training)
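
A quick way to see what `DropPath` does (an illustrative check, assuming the package layout of this upload): in eval mode it is the identity, while in training mode whole samples are zeroed and the survivors are scaled by `1 / keep_prob`, so the expected activation is unchanged.

```python
import torch
from depth_anything_v2.dinov2_layers.drop_path import DropPath

dp = DropPath(drop_prob=0.5)
x = torch.ones(6, 3, 8)

dp.eval()
assert torch.equal(dp(x), x)          # identity at inference time

dp.train()
y = dp(x)
# Each sample is either all zeros or scaled by 1 / keep_prob = 2.0
print(y.flatten(1).sum(dim=1))
```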
depth_anything_v2/dinov2_layers/layer_scale.py ADDED
@@ -0,0 +1,28 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
8
+
9
+ from typing import Union
10
+
11
+ import torch
12
+ from torch import Tensor
13
+ from torch import nn
14
+
15
+
16
+ class LayerScale(nn.Module):
17
+ def __init__(
18
+ self,
19
+ dim: int,
20
+ init_values: Union[float, Tensor] = 1e-5,
21
+ inplace: bool = False,
22
+ ) -> None:
23
+ super().__init__()
24
+ self.inplace = inplace
25
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
26
+
27
+ def forward(self, x: Tensor) -> Tensor:
28
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
depth_anything_v2/dinov2_layers/mlp.py ADDED
@@ -0,0 +1,41 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+
14
+ from torch import Tensor, nn
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(
19
+ self,
20
+ in_features: int,
21
+ hidden_features: Optional[int] = None,
22
+ out_features: Optional[int] = None,
23
+ act_layer: Callable[..., nn.Module] = nn.GELU,
24
+ drop: float = 0.0,
25
+ bias: bool = True,
26
+ ) -> None:
27
+ super().__init__()
28
+ out_features = out_features or in_features
29
+ hidden_features = hidden_features or in_features
30
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
31
+ self.act = act_layer()
32
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
33
+ self.drop = nn.Dropout(drop)
34
+
35
+ def forward(self, x: Tensor) -> Tensor:
36
+ x = self.fc1(x)
37
+ x = self.act(x)
38
+ x = self.drop(x)
39
+ x = self.fc2(x)
40
+ x = self.drop(x)
41
+ return x
depth_anything_v2/dinov2_layers/patch_embed.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+
13
+ from torch import Tensor
14
+ import torch.nn as nn
15
+
16
+
17
+ def make_2tuple(x):
18
+ if isinstance(x, tuple):
19
+ assert len(x) == 2
20
+ return x
21
+
22
+ assert isinstance(x, int)
23
+ return (x, x)
24
+
25
+
26
+ class PatchEmbed(nn.Module):
27
+ """
28
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
29
+
30
+ Args:
31
+ img_size: Image size.
32
+ patch_size: Patch token size.
33
+ in_chans: Number of input image channels.
34
+ embed_dim: Number of linear projection output channels.
35
+ norm_layer: Normalization layer.
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ img_size: Union[int, Tuple[int, int]] = 224,
41
+ patch_size: Union[int, Tuple[int, int]] = 16,
42
+ in_chans: int = 3,
43
+ embed_dim: int = 768,
44
+ norm_layer: Optional[Callable] = None,
45
+ flatten_embedding: bool = True,
46
+ ) -> None:
47
+ super().__init__()
48
+
49
+ image_HW = make_2tuple(img_size)
50
+ patch_HW = make_2tuple(patch_size)
51
+ patch_grid_size = (
52
+ image_HW[0] // patch_HW[0],
53
+ image_HW[1] // patch_HW[1],
54
+ )
55
+
56
+ self.img_size = image_HW
57
+ self.patch_size = patch_HW
58
+ self.patches_resolution = patch_grid_size
59
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
60
+
61
+ self.in_chans = in_chans
62
+ self.embed_dim = embed_dim
63
+
64
+ self.flatten_embedding = flatten_embedding
65
+
66
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
67
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
68
+
69
+ def forward(self, x: Tensor) -> Tensor:
70
+ _, _, H, W = x.shape
71
+ patch_H, patch_W = self.patch_size
72
+
73
+ assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
74
+ assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
75
+
76
+ x = self.proj(x) # B C H W
77
+ H, W = x.size(2), x.size(3)
78
+ x = x.flatten(2).transpose(1, 2) # B HW C
79
+ x = self.norm(x)
80
+ if not self.flatten_embedding:
81
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
82
+ return x
83
+
84
+ def flops(self) -> float:
85
+ Ho, Wo = self.patches_resolution
86
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
87
+ if self.norm is not None:
88
+ flops += Ho * Wo * self.embed_dim
89
+ return flops
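
To make the `(B, C, H, W) -> (B, N, D)` contract concrete, a small shape check against `PatchEmbed` (illustrative; the import path assumes the layout of this upload):

```python
import torch
from depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed

pe = PatchEmbed(img_size=224, patch_size=16, in_chans=3, embed_dim=768)
x = torch.randn(2, 3, 224, 224)
tokens = pe(x)
print(tokens.shape)    # torch.Size([2, 196, 768]): (224 / 16) ** 2 = 196 patches
print(pe.num_patches)  # 196
```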
depth_anything_v2/dinov2_layers/swiglu_ffn.py ADDED
@@ -0,0 +1,63 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+
9
+ from torch import Tensor, nn
10
+ import torch.nn.functional as F
11
+
12
+
13
+ class SwiGLUFFN(nn.Module):
14
+ def __init__(
15
+ self,
16
+ in_features: int,
17
+ hidden_features: Optional[int] = None,
18
+ out_features: Optional[int] = None,
19
+ act_layer: Callable[..., nn.Module] = None,
20
+ drop: float = 0.0,
21
+ bias: bool = True,
22
+ ) -> None:
23
+ super().__init__()
24
+ out_features = out_features or in_features
25
+ hidden_features = hidden_features or in_features
26
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
27
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
28
+
29
+ def forward(self, x: Tensor) -> Tensor:
30
+ x12 = self.w12(x)
31
+ x1, x2 = x12.chunk(2, dim=-1)
32
+ hidden = F.silu(x1) * x2
33
+ return self.w3(hidden)
34
+
35
+
36
+ try:
37
+ from xformers.ops import SwiGLU
38
+
39
+ XFORMERS_AVAILABLE = True
40
+ except ImportError:
41
+ SwiGLU = SwiGLUFFN
42
+ XFORMERS_AVAILABLE = False
43
+
44
+
45
+ class SwiGLUFFNFused(SwiGLU):
46
+ def __init__(
47
+ self,
48
+ in_features: int,
49
+ hidden_features: Optional[int] = None,
50
+ out_features: Optional[int] = None,
51
+ act_layer: Callable[..., nn.Module] = None,
52
+ drop: float = 0.0,
53
+ bias: bool = True,
54
+ ) -> None:
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
58
+ super().__init__(
59
+ in_features=in_features,
60
+ hidden_features=hidden_features,
61
+ out_features=out_features,
62
+ bias=bias,
63
+ )
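
`SwiGLUFFNFused` shrinks the requested hidden width to roughly two thirds (keeping the gated FFN's parameter count comparable to a plain MLP with the same `mlp_ratio`) and rounds it up to a multiple of 8. A short sketch of that arithmetic and of the forward shape, using the pure-PyTorch `SwiGLUFFN` fallback so it runs without xFormers:

```python
import torch
from depth_anything_v2.dinov2_layers.swiglu_ffn import SwiGLUFFN

# Hidden-width rule used by SwiGLUFFNFused: 4 * 1024 = 4096 -> * 2/3 -> 2730
# -> rounded up to the next multiple of 8 = 2736
print((int(4 * 1024 * 2 / 3) + 7) // 8 * 8)  # 2736

ffn = SwiGLUFFN(in_features=1024, hidden_features=2736)
x = torch.randn(2, 7, 1024)
print(ffn(x).shape)  # torch.Size([2, 7, 1024])
```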
depth_anything_v2/dpt.py ADDED
@@ -0,0 +1,221 @@
1
+ import cv2
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from torchvision.transforms import Compose
6
+
7
+ from .dinov2 import DINOv2
8
+ from .util.blocks import FeatureFusionBlock, _make_scratch
9
+ from .util.transform import Resize, NormalizeImage, PrepareForNet
10
+
11
+
12
+ def _make_fusion_block(features, use_bn, size=None):
13
+ return FeatureFusionBlock(
14
+ features,
15
+ nn.ReLU(False),
16
+ deconv=False,
17
+ bn=use_bn,
18
+ expand=False,
19
+ align_corners=True,
20
+ size=size,
21
+ )
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ def __init__(self, in_feature, out_feature):
26
+ super().__init__()
27
+
28
+ self.conv_block = nn.Sequential(
29
+ nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
30
+ nn.BatchNorm2d(out_feature),
31
+ nn.ReLU(True)
32
+ )
33
+
34
+ def forward(self, x):
35
+ return self.conv_block(x)
36
+
37
+
38
+ class DPTHead(nn.Module):
39
+ def __init__(
40
+ self,
41
+ in_channels,
42
+ features=256,
43
+ use_bn=False,
44
+ out_channels=[256, 512, 1024, 1024],
45
+ use_clstoken=False
46
+ ):
47
+ super(DPTHead, self).__init__()
48
+
49
+ self.use_clstoken = use_clstoken
50
+
51
+ self.projects = nn.ModuleList([
52
+ nn.Conv2d(
53
+ in_channels=in_channels,
54
+ out_channels=out_channel,
55
+ kernel_size=1,
56
+ stride=1,
57
+ padding=0,
58
+ ) for out_channel in out_channels
59
+ ])
60
+
61
+ self.resize_layers = nn.ModuleList([
62
+ nn.ConvTranspose2d(
63
+ in_channels=out_channels[0],
64
+ out_channels=out_channels[0],
65
+ kernel_size=4,
66
+ stride=4,
67
+ padding=0),
68
+ nn.ConvTranspose2d(
69
+ in_channels=out_channels[1],
70
+ out_channels=out_channels[1],
71
+ kernel_size=2,
72
+ stride=2,
73
+ padding=0),
74
+ nn.Identity(),
75
+ nn.Conv2d(
76
+ in_channels=out_channels[3],
77
+ out_channels=out_channels[3],
78
+ kernel_size=3,
79
+ stride=2,
80
+ padding=1)
81
+ ])
82
+
83
+ if use_clstoken:
84
+ self.readout_projects = nn.ModuleList()
85
+ for _ in range(len(self.projects)):
86
+ self.readout_projects.append(
87
+ nn.Sequential(
88
+ nn.Linear(2 * in_channels, in_channels),
89
+ nn.GELU()))
90
+
91
+ self.scratch = _make_scratch(
92
+ out_channels,
93
+ features,
94
+ groups=1,
95
+ expand=False,
96
+ )
97
+
98
+ self.scratch.stem_transpose = None
99
+
100
+ self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
101
+ self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
102
+ self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
103
+ self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
104
+
105
+ head_features_1 = features
106
+ head_features_2 = 32
107
+
108
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
109
+ self.scratch.output_conv2 = nn.Sequential(
110
+ nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
111
+ nn.ReLU(True),
112
+ nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
113
+ nn.ReLU(True),
114
+ nn.Identity(),
115
+ )
116
+
117
+ def forward(self, out_features, patch_h, patch_w):
118
+ out = []
119
+ for i, x in enumerate(out_features):
120
+ if self.use_clstoken:
121
+ x, cls_token = x[0], x[1]
122
+ readout = cls_token.unsqueeze(1).expand_as(x)
123
+ x = self.readout_projects[i](torch.cat((x, readout), -1))
124
+ else:
125
+ x = x[0]
126
+
127
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
128
+
129
+ x = self.projects[i](x)
130
+ x = self.resize_layers[i](x)
131
+
132
+ out.append(x)
133
+
134
+ layer_1, layer_2, layer_3, layer_4 = out
135
+
136
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
137
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
138
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
139
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
140
+
141
+ path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
142
+ path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
143
+ path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
144
+ path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
145
+
146
+ out = self.scratch.output_conv1(path_1)
147
+ out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
148
+ out = self.scratch.output_conv2(out)
149
+
150
+ return out
151
+
152
+
153
+ class DepthAnythingV2(nn.Module):
154
+ def __init__(
155
+ self,
156
+ encoder='vitl',
157
+ features=256,
158
+ out_channels=[256, 512, 1024, 1024],
159
+ use_bn=False,
160
+ use_clstoken=False
161
+ ):
162
+ super(DepthAnythingV2, self).__init__()
163
+
164
+ self.intermediate_layer_idx = {
165
+ 'vits': [2, 5, 8, 11],
166
+ 'vitb': [2, 5, 8, 11],
167
+ 'vitl': [4, 11, 17, 23],
168
+ 'vitg': [9, 19, 29, 39]
169
+ }
170
+
171
+ self.encoder = encoder
172
+ self.pretrained = DINOv2(model_name=encoder)
173
+
174
+ self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
175
+
176
+ def forward(self, x):
177
+ patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
178
+
179
+ features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
180
+
181
+ depth = self.depth_head(features, patch_h, patch_w)
182
+ depth = F.relu(depth)
183
+
184
+ return depth.squeeze(1)
185
+
186
+ @torch.no_grad()
187
+ def infer_image(self, raw_image, input_size=518):
188
+ image, (h, w) = self.image2tensor(raw_image, input_size)
189
+
190
+ depth = self.forward(image)
191
+
192
+ depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
193
+
194
+ return depth.cpu().numpy()
195
+
196
+ def image2tensor(self, raw_image, input_size=518):
197
+ transform = Compose([
198
+ Resize(
199
+ width=input_size,
200
+ height=input_size,
201
+ resize_target=False,
202
+ keep_aspect_ratio=True,
203
+ ensure_multiple_of=14,
204
+ resize_method='lower_bound',
205
+ image_interpolation_method=cv2.INTER_CUBIC,
206
+ ),
207
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
208
+ PrepareForNet(),
209
+ ])
210
+
211
+ h, w = raw_image.shape[:2]
212
+
213
+ image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
214
+
215
+ image = transform({'image': image})['image']
216
+ image = torch.from_numpy(image).unsqueeze(0)
217
+
218
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
219
+ image = image.to(DEVICE)
220
+
221
+ return image, (h, w)
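
Putting the pieces together, `DepthAnythingV2.infer_image` takes a BGR `numpy` image (as read by `cv2.imread`) and returns an HxW float depth map. A hedged usage sketch, mirroring the loading code in `main_app.py` further down; the checkpoint path and image file name are assumptions:

```python
import cv2
import torch
from depth_anything_v2.dpt import DepthAnythingV2

DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

# 'vitl' configuration as used in main_app.py; the .pth path is an assumption.
model = DepthAnythingV2(encoder='vitl', features=256,
                        out_channels=[256, 512, 1024, 1024])
model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vitl.pth',
                                 map_location='cpu'))
model = model.to(DEVICE).eval()

raw_image = cv2.imread('example.jpg')   # BGR, as expected by infer_image
depth = model.infer_image(raw_image)    # float32 HxW relative depth map
print(depth.shape, float(depth.min()), float(depth.max()))
```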
depth_anything_v2/util/__pycache__/blocks.cpython-312.pyc ADDED
Binary file (5.55 kB).
 
depth_anything_v2/util/__pycache__/transform.cpython-312.pyc ADDED
Binary file (7.45 kB).
 
depth_anything_v2/util/blocks.py ADDED
@@ -0,0 +1,148 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ def _make_scratch(in_shape, out_shape, groups=1, expand=False):
5
+ scratch = nn.Module()
6
+
7
+ out_shape1 = out_shape
8
+ out_shape2 = out_shape
9
+ out_shape3 = out_shape
10
+ if len(in_shape) >= 4:
11
+ out_shape4 = out_shape
12
+
13
+ if expand:
14
+ out_shape1 = out_shape
15
+ out_shape2 = out_shape * 2
16
+ out_shape3 = out_shape * 4
17
+ if len(in_shape) >= 4:
18
+ out_shape4 = out_shape * 8
19
+
20
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
21
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
22
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
23
+ if len(in_shape) >= 4:
24
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
25
+
26
+ return scratch
27
+
28
+
29
+ class ResidualConvUnit(nn.Module):
30
+ """Residual convolution module.
31
+ """
32
+
33
+ def __init__(self, features, activation, bn):
34
+ """Init.
35
+
36
+ Args:
37
+ features (int): number of features
38
+ """
39
+ super().__init__()
40
+
41
+ self.bn = bn
42
+
43
+ self.groups=1
44
+
45
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
46
+
47
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
48
+
49
+ if self.bn == True:
50
+ self.bn1 = nn.BatchNorm2d(features)
51
+ self.bn2 = nn.BatchNorm2d(features)
52
+
53
+ self.activation = activation
54
+
55
+ self.skip_add = nn.quantized.FloatFunctional()
56
+
57
+ def forward(self, x):
58
+ """Forward pass.
59
+
60
+ Args:
61
+ x (tensor): input
62
+
63
+ Returns:
64
+ tensor: output
65
+ """
66
+
67
+ out = self.activation(x)
68
+ out = self.conv1(out)
69
+ if self.bn == True:
70
+ out = self.bn1(out)
71
+
72
+ out = self.activation(out)
73
+ out = self.conv2(out)
74
+ if self.bn == True:
75
+ out = self.bn2(out)
76
+
77
+ if self.groups > 1:
78
+ out = self.conv_merge(out)
79
+
80
+ return self.skip_add.add(out, x)
81
+
82
+
83
+ class FeatureFusionBlock(nn.Module):
84
+ """Feature fusion block.
85
+ """
86
+
87
+ def __init__(
88
+ self,
89
+ features,
90
+ activation,
91
+ deconv=False,
92
+ bn=False,
93
+ expand=False,
94
+ align_corners=True,
95
+ size=None
96
+ ):
97
+ """Init.
98
+
99
+ Args:
100
+ features (int): number of features
101
+ """
102
+ super(FeatureFusionBlock, self).__init__()
103
+
104
+ self.deconv = deconv
105
+ self.align_corners = align_corners
106
+
107
+ self.groups=1
108
+
109
+ self.expand = expand
110
+ out_features = features
111
+ if self.expand == True:
112
+ out_features = features // 2
113
+
114
+ self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
115
+
116
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
117
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
118
+
119
+ self.skip_add = nn.quantized.FloatFunctional()
120
+
121
+ self.size=size
122
+
123
+ def forward(self, *xs, size=None):
124
+ """Forward pass.
125
+
126
+ Returns:
127
+ tensor: output
128
+ """
129
+ output = xs[0]
130
+
131
+ if len(xs) == 2:
132
+ res = self.resConfUnit1(xs[1])
133
+ output = self.skip_add.add(output, res)
134
+
135
+ output = self.resConfUnit2(output)
136
+
137
+ if (size is None) and (self.size is None):
138
+ modifier = {"scale_factor": 2}
139
+ elif size is None:
140
+ modifier = {"size": self.size}
141
+ else:
142
+ modifier = {"size": size}
143
+
144
+ output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
145
+
146
+ output = self.out_conv(output)
147
+
148
+ return output
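
A small shape sketch of how these blocks are wired by `DPTHead` above: `_make_scratch` maps each backbone stage to a common channel width, and a `FeatureFusionBlock` adds two same-resolution feature maps and upsamples the result (by 2x when no explicit `size` is passed). Illustrative only:

```python
import torch
import torch.nn as nn
from depth_anything_v2.util.blocks import FeatureFusionBlock, _make_scratch

scratch = _make_scratch([256, 512, 1024, 1024], 256, groups=1, expand=False)
print(scratch.layer1_rn)  # Conv2d(256, 256, kernel_size=(3, 3), ...)

fuse = FeatureFusionBlock(256, nn.ReLU(False), deconv=False, bn=False,
                          expand=False, align_corners=True)
deep = torch.randn(1, 256, 32, 32)  # coarser decoder path (already resized)
skip = torch.randn(1, 256, 32, 32)  # projected backbone feature at the same size
out = fuse(deep, skip)
print(out.shape)  # torch.Size([1, 256, 64, 64]) -- fused, then upsampled x2
```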
depth_anything_v2/util/transform.py ADDED
@@ -0,0 +1,158 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+
5
+ class Resize(object):
6
+ """Resize sample to given size (width, height).
7
+ """
8
+
9
+ def __init__(
10
+ self,
11
+ width,
12
+ height,
13
+ resize_target=True,
14
+ keep_aspect_ratio=False,
15
+ ensure_multiple_of=1,
16
+ resize_method="lower_bound",
17
+ image_interpolation_method=cv2.INTER_AREA,
18
+ ):
19
+ """Init.
20
+
21
+ Args:
22
+ width (int): desired output width
23
+ height (int): desired output height
24
+ resize_target (bool, optional):
25
+ True: Resize the full sample (image, mask, target).
26
+ False: Resize image only.
27
+ Defaults to True.
28
+ keep_aspect_ratio (bool, optional):
29
+ True: Keep the aspect ratio of the input sample.
30
+ Output sample might not have the given width and height, and
31
+ resize behaviour depends on the parameter 'resize_method'.
32
+ Defaults to False.
33
+ ensure_multiple_of (int, optional):
34
+ Output width and height is constrained to be multiple of this parameter.
35
+ Defaults to 1.
36
+ resize_method (str, optional):
37
+ "lower_bound": Output will be at least as large as the given size.
38
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
39
+ "minimal": Scale as least as possible. (Output size might be smaller than given size.)
40
+ Defaults to "lower_bound".
41
+ """
42
+ self.__width = width
43
+ self.__height = height
44
+
45
+ self.__resize_target = resize_target
46
+ self.__keep_aspect_ratio = keep_aspect_ratio
47
+ self.__multiple_of = ensure_multiple_of
48
+ self.__resize_method = resize_method
49
+ self.__image_interpolation_method = image_interpolation_method
50
+
51
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
52
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
53
+
54
+ if max_val is not None and y > max_val:
55
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
56
+
57
+ if y < min_val:
58
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
59
+
60
+ return y
61
+
62
+ def get_size(self, width, height):
63
+ # determine new height and width
64
+ scale_height = self.__height / height
65
+ scale_width = self.__width / width
66
+
67
+ if self.__keep_aspect_ratio:
68
+ if self.__resize_method == "lower_bound":
69
+ # scale such that output size is lower bound
70
+ if scale_width > scale_height:
71
+ # fit width
72
+ scale_height = scale_width
73
+ else:
74
+ # fit height
75
+ scale_width = scale_height
76
+ elif self.__resize_method == "upper_bound":
77
+ # scale such that output size is upper bound
78
+ if scale_width < scale_height:
79
+ # fit width
80
+ scale_height = scale_width
81
+ else:
82
+ # fit height
83
+ scale_width = scale_height
84
+ elif self.__resize_method == "minimal":
85
+ # scale as little as possible
86
+ if abs(1 - scale_width) < abs(1 - scale_height):
87
+ # fit width
88
+ scale_height = scale_width
89
+ else:
90
+ # fit height
91
+ scale_width = scale_height
92
+ else:
93
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
94
+
95
+ if self.__resize_method == "lower_bound":
96
+ new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
97
+ new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
98
+ elif self.__resize_method == "upper_bound":
99
+ new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
100
+ new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
101
+ elif self.__resize_method == "minimal":
102
+ new_height = self.constrain_to_multiple_of(scale_height * height)
103
+ new_width = self.constrain_to_multiple_of(scale_width * width)
104
+ else:
105
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
106
+
107
+ return (new_width, new_height)
108
+
109
+ def __call__(self, sample):
110
+ width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
111
+
112
+ # resize sample
113
+ sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
114
+
115
+ if self.__resize_target:
116
+ if "depth" in sample:
117
+ sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
118
+
119
+ if "mask" in sample:
120
+ sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
121
+
122
+ return sample
123
+
124
+
125
+ class NormalizeImage(object):
126
+ """Normlize image by given mean and std.
127
+ """
128
+
129
+ def __init__(self, mean, std):
130
+ self.__mean = mean
131
+ self.__std = std
132
+
133
+ def __call__(self, sample):
134
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
135
+
136
+ return sample
137
+
138
+
139
+ class PrepareForNet(object):
140
+ """Prepare sample for usage as network input.
141
+ """
142
+
143
+ def __init__(self):
144
+ pass
145
+
146
+ def __call__(self, sample):
147
+ image = np.transpose(sample["image"], (2, 0, 1))
148
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
149
+
150
+ if "depth" in sample:
151
+ depth = sample["depth"].astype(np.float32)
152
+ sample["depth"] = np.ascontiguousarray(depth)
153
+
154
+ if "mask" in sample:
155
+ sample["mask"] = sample["mask"].astype(np.float32)
156
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
157
+
158
+ return sample
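
These three transforms are composed exactly as in `DepthAnythingV2.image2tensor`: keep the aspect ratio, make both sides at least `input_size` and a multiple of 14 (the ViT patch size), normalize with ImageNet statistics, and move channels first. A small sketch with a random stand-in image (already scaled to [0, 1]):

```python
import cv2
import numpy as np
from torchvision.transforms import Compose
from depth_anything_v2.util.transform import Resize, NormalizeImage, PrepareForNet

transform = Compose([
    Resize(width=518, height=518, resize_target=False, keep_aspect_ratio=True,
           ensure_multiple_of=14, resize_method='lower_bound',
           image_interpolation_method=cv2.INTER_CUBIC),
    NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    PrepareForNet(),
])

rgb = np.random.rand(480, 640, 3).astype(np.float32)  # stand-in for image / 255.0
out = transform({'image': rgb})['image']
print(out.shape)  # (3, 518, 686): short side scaled up to 518, both sides rounded to a multiple of 14
```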
environment.yml ADDED
Binary file (6.93 kB).
 
environment_export.yml ADDED
@@ -0,0 +1,182 @@
1
+ name: depth_copy
2
+ channels:
3
+ - conda-forge
4
+ - defaults
5
+ dependencies:
6
+ - _libgcc_mutex=0.1=conda_forge
7
+ - _openmp_mutex=4.5=2_gnu
8
+ - bzip2=1.0.8=h4bc722e_7
9
+ - ca-certificates=2025.6.15=hbd8a1cb_0
10
+ - ld_impl_linux-64=2.44=h1423503_0
11
+ - libexpat=2.7.0=h5888daf_0
12
+ - libffi=3.4.6=h2dba641_1
13
+ - libgcc=15.1.0=h767d61c_3
14
+ - libgcc-ng=15.1.0=h69a702a_3
15
+ - libgomp=15.1.0=h767d61c_3
16
+ - liblzma=5.8.1=hb9d3cd8_2
17
+ - libnsl=2.0.1=hb9d3cd8_1
18
+ - libsqlite=3.50.2=h6cd9bfd_0
19
+ - libuuid=2.38.1=h0b41bf4_0
20
+ - libxcrypt=4.4.36=hd590300_1
21
+ - libzlib=1.3.1=hb9d3cd8_2
22
+ - ncurses=6.5=h2d0b736_3
23
+ - openssl=3.5.1=h7b32b05_0
24
+ - pip=25.1.1=pyh8b19718_0
25
+ - python=3.12.11=h9e4cc4f_0_cpython
26
+ - readline=8.2=h8c095d6_2
27
+ - setuptools=80.9.0=pyhff2d567_0
28
+ - tk=8.6.13=noxft_hd72426e_102
29
+ - wheel=0.45.1=pyhd8ed1ab_1
30
+ - pip:
31
+ - absl-py==2.3.1
32
+ - addict==2.4.0
33
+ - aiofiles==24.1.0
34
+ - annotated-types==0.7.0
35
+ - anyio==4.9.0
36
+ - asttokens==3.0.0
37
+ - astunparse==1.6.3
38
+ - attrs==25.3.0
39
+ - blinker==1.9.0
40
+ - certifi==2025.6.15
41
+ - charset-normalizer==3.4.2
42
+ - click==8.2.1
43
+ - colorama==0.4.6
44
+ - comm==0.2.2
45
+ - configargparse==1.7.1
46
+ - contourpy==1.3.2
47
+ - cycler==0.12.1
48
+ - dash==3.1.1
49
+ - decorator==5.2.1
50
+ - executing==2.2.0
51
+ - fastapi==0.115.14
52
+ - fastjsonschema==2.21.1
53
+ - ffmpy==0.6.0
54
+ - filelock==3.18.0
55
+ - flask==3.1.1
56
+ - flatbuffers==25.2.10
57
+ - fonttools==4.58.4
58
+ - fsspec==2025.5.1
59
+ - gast==0.6.0
60
+ - google-pasta==0.2.0
61
+ - gradio==5.35.0
62
+ - gradio-client==1.10.4
63
+ - gradio-imageslider==0.0.20
64
+ - groovy==0.1.2
65
+ - grpcio==1.73.1
66
+ - h11==0.16.0
67
+ - h5py==3.14.0
68
+ - hf-xet==1.1.5
69
+ - httpcore==1.0.9
70
+ - httpx==0.28.1
71
+ - huggingface-hub==0.33.2
72
+ - idna==3.10
73
+ - importlib-metadata==8.7.0
74
+ - ipython==9.4.0
75
+ - ipython-pygments-lexers==1.1.1
76
+ - ipywidgets==8.1.7
77
+ - itsdangerous==2.2.0
78
+ - jedi==0.19.2
79
+ - jinja2==3.1.6
80
+ - joblib==1.5.1
81
+ - jsonschema==4.24.0
82
+ - jsonschema-specifications==2025.4.1
83
+ - jupyter-core==5.8.1
84
+ - jupyterlab-widgets==3.0.15
85
+ - keras==3.10.0
86
+ - kiwisolver==1.4.8
87
+ - libclang==18.1.1
88
+ - markdown==3.8.2
89
+ - markdown-it-py==3.0.0
90
+ - markupsafe==3.0.2
91
+ - matplotlib==3.10.3
92
+ - matplotlib-inline==0.1.7
93
+ - mdurl==0.1.2
94
+ - ml-dtypes==0.5.1
95
+ - mpmath==1.3.0
96
+ - namex==0.1.0
97
+ - narwhals==1.45.0
98
+ - nbformat==5.10.4
99
+ - nest-asyncio==1.6.0
100
+ - networkx==3.5
101
+ - numpy==2.1.3
102
+ - nvidia-cublas-cu12==12.6.4.1
103
+ - nvidia-cuda-cupti-cu12==12.6.80
104
+ - nvidia-cuda-nvrtc-cu12==12.6.77
105
+ - nvidia-cuda-runtime-cu12==12.6.77
106
+ - nvidia-cudnn-cu12==9.5.1.17
107
+ - nvidia-cufft-cu12==11.3.0.4
108
+ - nvidia-cufile-cu12==1.11.1.6
109
+ - nvidia-curand-cu12==10.3.7.77
110
+ - nvidia-cusolver-cu12==11.7.1.2
111
+ - nvidia-cusparse-cu12==12.5.4.2
112
+ - nvidia-cusparselt-cu12==0.6.3
113
+ - nvidia-nccl-cu12==2.26.2
114
+ - nvidia-nvjitlink-cu12==12.6.85
115
+ - nvidia-nvtx-cu12==12.6.77
116
+ - open3d==0.19.0
117
+ - opencv-python==4.11.0.86
118
+ - opt-einsum==3.4.0
119
+ - optree==0.16.0
120
+ - orjson==3.10.18
121
+ - packaging==25.0
122
+ - pandas==2.3.0
123
+ - parso==0.8.4
124
+ - pexpect==4.9.0
125
+ - pillow==11.3.0
126
+ - platformdirs==4.3.8
127
+ - plotly==6.2.0
128
+ - prompt-toolkit==3.0.51
129
+ - protobuf==5.29.5
130
+ - ptyprocess==0.7.0
131
+ - pure-eval==0.2.3
132
+ - pydantic==2.11.7
133
+ - pydantic-core==2.33.2
134
+ - pydub==0.25.1
135
+ - pygments==2.19.2
136
+ - pyparsing==3.2.3
137
+ - pyquaternion==0.9.9
138
+ - python-dateutil==2.9.0.post0
139
+ - python-multipart==0.0.20
140
+ - pytz==2025.2
141
+ - pyyaml==6.0.2
142
+ - referencing==0.36.2
143
+ - requests==2.32.4
144
+ - retrying==1.4.0
145
+ - rich==14.0.0
146
+ - rpds-py==0.26.0
147
+ - ruff==0.12.1
148
+ - safehttpx==0.1.6
149
+ - scikit-learn==1.7.0
150
+ - scipy==1.16.0
151
+ - semantic-version==2.10.0
152
+ - shellingham==1.5.4
153
+ - six==1.17.0
154
+ - sniffio==1.3.1
155
+ - stack-data==0.6.3
156
+ - starlette==0.46.2
157
+ - sympy==1.14.0
158
+ - tensorboard==2.19.0
159
+ - tensorboard-data-server==0.7.2
160
+ - tensorflow==2.19.0
161
+ - termcolor==3.1.0
162
+ - threadpoolctl==3.6.0
163
+ - tomlkit==0.13.3
164
+ - torch==2.7.1
165
+ - torchaudio==2.7.1
166
+ - torchvision==0.22.1
167
+ - tqdm==4.67.1
168
+ - traitlets==5.14.3
169
+ - triton==3.3.1
170
+ - typer==0.16.0
171
+ - typing-extensions==4.14.0
172
+ - typing-inspection==0.4.1
173
+ - tzdata==2025.2
174
+ - urllib3==2.5.0
175
+ - uvicorn==0.35.0
176
+ - wcwidth==0.2.13
177
+ - websockets==15.0.1
178
+ - werkzeug==3.1.3
179
+ - widgetsnbextension==4.0.14
180
+ - wrapt==1.17.2
181
+ - zipp==3.23.0
182
+ prefix: /home/uphen/anaconda3/envs/depth_copy
environment_from_history.yml ADDED
@@ -0,0 +1,30 @@
1
+ name: depth_copy
2
+ channels:
3
+ - defaults
4
+ dependencies:
5
+ - conda-forge/linux-64::_libgcc_mutex==0.1=conda_forge
6
+ - conda-forge/noarch::ca-certificates==2025.6.15=hbd8a1cb_0
7
+ - conda-forge/linux-64::ld_impl_linux-64==2.44=h1423503_0
8
+ - conda-forge/linux-64::libgomp==15.1.0=h767d61c_3
9
+ - conda-forge/noarch::tzdata==2025b=h78e105d_0
10
+ - conda-forge/linux-64::_openmp_mutex==4.5=2_gnu
11
+ - conda-forge/linux-64::libgcc==15.1.0=h767d61c_3
12
+ - conda-forge/linux-64::libexpat==2.7.0=h5888daf_0
13
+ - conda-forge/linux-64::libffi==3.4.6=h2dba641_1
14
+ - conda-forge/linux-64::libgcc-ng==15.1.0=h69a702a_3
15
+ - conda-forge/linux-64::liblzma==5.8.1=hb9d3cd8_2
16
+ - conda-forge/linux-64::libnsl==2.0.1=hb9d3cd8_1
17
+ - conda-forge/linux-64::libzlib==1.3.1=hb9d3cd8_2
18
+ - conda-forge/linux-64::ncurses==6.5=h2d0b736_3
19
+ - conda-forge/linux-64::openssl==3.5.1=h7b32b05_0
20
+ - conda-forge/linux-64::bzip2==1.0.8=h4bc722e_7
21
+ - conda-forge/linux-64::libsqlite==3.50.2=h6cd9bfd_0
22
+ - conda-forge/linux-64::libuuid==2.38.1=h0b41bf4_0
23
+ - conda-forge/linux-64::libxcrypt==4.4.36=hd590300_1
24
+ - conda-forge/linux-64::readline==8.2=h8c095d6_2
25
+ - conda-forge/linux-64::tk==8.6.13=noxft_hd72426e_102
26
+ - conda-forge/linux-64::python==3.12.11=h9e4cc4f_0_cpython
27
+ - conda-forge/noarch::setuptools==80.9.0=pyhff2d567_0
28
+ - conda-forge/noarch::wheel==0.45.1=pyhd8ed1ab_1
29
+ - conda-forge/noarch::pip==25.1.1=pyh8b19718_0
30
+ prefix: /home/uphen/anaconda3/envs/depth_copy
environment_linux.yml ADDED
@@ -0,0 +1,116 @@
1
+ name: depth
2
+ channels:
3
+ - conda-forge
4
+ - pytorch
5
+ - defaults
6
+ dependencies:
7
+ - python=3.12
8
+ - pip
9
+ - pip:
10
+ - aiofiles==24.1.0
11
+ - annotated-types==0.7.0
12
+ - anyio==4.9.0
13
+ - asttokens==3.0.0
14
+ - attrs==25.3.0
15
+ - blinker==1.9.0
16
+ - certifi==2025.6.15
17
+ - charset-normalizer==3.4.2
18
+ - click==8.2.1
19
+ - colorama==0.4.6
20
+ - comm==0.2.2
21
+ - configargparse==1.7.1
22
+ - contourpy==1.3.2
23
+ - cycler==0.12.1
24
+ - dash==3.1.1
25
+ - decorator==5.2.1
26
+ - executing==2.2.0
27
+ - fastapi==0.115.14
28
+ - fastjsonschema==2.21.1
29
+ - ffmpy==0.6.0
30
+ - filelock==3.18.0
31
+ - flask==3.1.1
32
+ - fonttools==4.58.4
33
+ - fsspec==2025.5.1
34
+ - gradio==5.35.0
35
+ - gradio-client==1.10.4
36
+ - gradio-imageslider==0.0.20
37
+ - groovy==0.1.2
38
+ - h11==0.16.0
39
+ - httpcore==1.0.9
40
+ - httpx==0.28.1
41
+ - huggingface-hub==0.33.2
42
+ - idna==3.10
43
+ - importlib-metadata==8.7.0
44
+ - ipython==9.4.0
45
+ - ipython-pygments-lexers==1.1.1
46
+ - ipywidgets==8.1.7
47
+ - itsdangerous==2.2.0
48
+ - jedi==0.19.2
49
+ - jinja2==3.1.6
50
+ - jsonschema==4.24.0
51
+ - jsonschema-specifications==2025.4.1
52
+ - jupyter-core==5.8.1
53
+ - jupyterlab-widgets==3.0.15
54
+ - kiwisolver==1.4.8
55
+ - markdown-it-py==3.0.0
56
+ - markupsafe==3.0.2
57
+ - matplotlib==3.10.3
58
+ - matplotlib-inline==0.1.7
59
+ - mdurl==0.1.2
60
+ - mpmath==1.3.0
61
+ - narwhals==1.45.0
62
+ - nbformat==5.10.4
63
+ - nest-asyncio==1.6.0
64
+ - networkx==3.5
65
+ - numpy==2.3.1
66
+ - open3d==0.19.0
67
+ - opencv-python==4.11.0.86
68
+ - orjson==3.10.18
69
+ - packaging==25.0
70
+ - pandas==2.3.0
71
+ - parso==0.8.4
72
+ - pillow==11.3.0
73
+ - platformdirs==4.3.8
74
+ - plotly==6.2.0
75
+ - prompt-toolkit==3.0.51
76
+ - pure-eval==0.2.3
77
+ - pydantic==2.11.7
78
+ - pydantic-core==2.33.2
79
+ - pydub==0.25.1
80
+ - pygments==2.19.2
81
+ - pyparsing==3.2.3
82
+ - python-dateutil==2.9.0.post0
83
+ - python-multipart==0.0.20
84
+ - pytz==2025.2
85
+ - pyyaml==6.0.2
86
+ - referencing==0.36.2
87
+ - requests==2.32.4
88
+ - retrying==1.4.0
89
+ - rich==14.0.0
90
+ - rpds-py==0.26.0
91
+ - ruff==0.12.1
92
+ - safehttpx==0.1.6
93
+ - semantic-version==2.10.0
94
+ - shellingham==1.5.4
95
+ - six==1.17.0
96
+ - sniffio==1.3.1
97
+ - stack-data==0.6.3
98
+ - starlette==0.46.2
99
+ - sympy==1.14.0
100
+ - tomlkit==0.13.3
101
+ - torch==2.7.1
102
+ - torchaudio==2.7.1
103
+ - torchvision==0.22.1
104
+ - tqdm==4.67.1
105
+ - traitlets==5.14.3
106
+ - typer==0.16.0
107
+ - typing-extensions==4.14.0
108
+ - typing-inspection==0.4.1
109
+ - tzdata==2025.2
110
+ - urllib3==2.5.0
111
+ - uvicorn==0.35.0
112
+ - wcwidth==0.2.13
113
+ - websockets==15.0.1
114
+ - werkzeug==3.1.3
115
+ - widgetsnbextension==4.0.14
116
+ - zipp==3.23.0
keras_model_3.h5 ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4669439786977b2098b89ad88387538d8085bd129e1e6b9de7f7821aa8baa3fa
3
+ size 2453440
labels.txt ADDED
@@ -0,0 +1,4 @@
1
+ 0 Burns
2
+ 1 Surgical Wou...
3
+ 2 Traumatic Wo...
4
+ 3 Diabetic Foo...
main_app.py ADDED
@@ -0,0 +1,540 @@
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing import image as keras_image
6
+ from tensorflow.keras import backend as K
7
+ import matplotlib.pyplot as plt
8
+ from PIL import Image
9
+ import io
10
+ import cv2
11
+ import glob
12
+ import matplotlib
13
+ import torch
14
+ import tempfile
15
+ from gradio_imageslider import ImageSlider
16
+ import plotly.graph_objects as go
17
+ import plotly.express as px
18
+ import open3d as o3d
19
+ from depth_anything_v2.dpt import DepthAnythingV2
20
+
21
+ # --- Load models ---
22
+ # Wound classification model
23
+ try:
24
+ wound_model = load_model("checkpoints/keras_model.h5")
25
+ with open("labels.txt", "r") as f:
26
+ class_labels = [line.strip() for line in f]
27
+ except Exception:  # fall back gracefully if the classifier or labels file is missing
28
+ wound_model = None
29
+ class_labels = ["No model found"]
30
+
31
+ # Depth estimation model
32
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
33
+ model_configs = {
34
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
35
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
36
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
37
+ 'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
38
+ }
39
+ encoder = 'vitl'
40
+ try:
41
+ depth_model = DepthAnythingV2(**model_configs[encoder])
42
+ state_dict = torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location="cpu")
43
+ depth_model.load_state_dict(state_dict)
44
+ depth_model = depth_model.to(DEVICE).eval()
45
+ except Exception:  # fall back gracefully if the depth checkpoint is missing
46
+ depth_model = None
47
+
48
+ # --- Wound Classification Functions ---
49
+ def preprocess_input(img):
50
+ img = img.resize((224, 224))
51
+ arr = keras_image.img_to_array(img)
52
+ arr = arr / 255.0
53
+ return np.expand_dims(arr, axis=0)
54
+
55
+ def get_gradcam_heatmap(img_array, model, class_index, last_conv_layer_name="conv5_block3_out"):
56
+ try:
57
+ target_layer = model.get_layer(last_conv_layer_name)
58
+ except Exception:  # named layer not found; fall back to the first conv layer below
59
+ for layer in model.layers:
60
+ if 'conv' in layer.name.lower():
61
+ target_layer = layer
62
+ break
63
+ else:
64
+ return None
65
+
66
+ grad_model = tf.keras.models.Model(
67
+ [model.inputs], [target_layer.output, model.output]
68
+ )
69
+
70
+ with tf.GradientTape() as tape:
71
+ conv_outputs, predictions = grad_model(img_array)
72
+ loss = predictions[:, class_index]
73
+
74
+ grads = tape.gradient(loss, conv_outputs)
75
+ if grads is None:
76
+ return None
77
+
78
+ grads = grads[0]
79
+ pooled_grads = tf.reduce_mean(grads, axis=(0, 1, 2))
80
+ conv_outputs = conv_outputs[0]
81
+
82
+ heatmap = tf.reduce_sum(tf.multiply(pooled_grads, conv_outputs), axis=-1)
83
+ heatmap = np.maximum(heatmap, 0)
84
+ heatmap = heatmap / np.max(heatmap + K.epsilon())
85
+ return heatmap.numpy()
86
+
87
+ def overlay_gradcam(original_img, heatmap):
88
+ if heatmap is None:
89
+ return original_img
90
+
91
+ heatmap = cv2.resize(heatmap, original_img.size)
92
+ heatmap = np.maximum(heatmap, 0)
93
+ if np.max(heatmap) != 0:
94
+ heatmap /= np.max(heatmap)
95
+ heatmap = np.uint8(255 * heatmap)
96
+
97
+ heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
98
+ original_array = np.array(original_img.convert("RGB"))
99
+ superimposed_img = cv2.addWeighted(original_array, 0.6, heatmap_color, 0.4, 0)
100
+
101
+ return Image.fromarray(superimposed_img)
102
+
103
+ def classify_and_explain(img):
104
+ if img is None or wound_model is None:
105
+ return None, "No image provided or model not available"
106
+
107
+ img_array = preprocess_input(img)
108
+ predictions = wound_model.predict(img_array, verbose=0)[0]
109
+ pred_idx = int(np.argmax(predictions))
110
+ pred_class = class_labels[pred_idx]
111
+ confidence_dict = {class_labels[i]: float(predictions[i]) for i in range(len(class_labels))}
112
+
113
+ try:
114
+ heatmap = get_gradcam_heatmap(img_array, wound_model, pred_idx)
115
+ gradcam_img = overlay_gradcam(img.resize((224, 224)), heatmap)
116
+ except Exception as e:
117
+ print(f"Grad-CAM error: {e}")
118
+ gradcam_img = img.resize((224, 224))
119
+
120
+ return gradcam_img, confidence_dict
121
+
122
+ def create_confidence_bars(confidence_dict):
123
+ html_content = "<div class='confidence-container'>"
124
+ for class_name, confidence in confidence_dict.items():
125
+ percentage = confidence * 100
126
+ if percentage > 70:
127
+ css_class = "confidence-high"
128
+ elif percentage > 40:
129
+ css_class = "confidence-medium"
130
+ else:
131
+ css_class = "confidence-low"
132
+
133
+ html_content += f"""
134
+ <div style='margin: 12px 0;'>
135
+ <div style='display: flex; justify-content: space-between; margin-bottom: 8px;'>
136
+ <span style='font-weight: bold;'>{class_name}</span>
137
+ <span style='font-weight: bold;'>{percentage:.1f}%</span>
138
+ </div>
139
+ <div class='confidence-bar {css_class}' style='width: {percentage}%;'></div>
140
+ </div>
141
+ """
142
+ html_content += "</div>"
143
+ return html_content
144
+
145
+ def enhanced_classify_and_explain(img):
146
+ if img is None:
147
+ return None, "No image provided", 0, ""
148
+
149
+ gradcam_img, confidence_dict = classify_and_explain(img)
150
+
151
+ if isinstance(confidence_dict, str): # Error case
152
+ return None, confidence_dict, 0, ""
153
+
154
+ pred_class = max(confidence_dict, key=confidence_dict.get)
155
+ confidence = confidence_dict[pred_class]
156
+ confidence_bars_html = create_confidence_bars(confidence_dict)
157
+
158
+ return gradcam_img, pred_class, confidence, confidence_bars_html
159
+
160
+ # --- Depth Estimation Functions ---
161
+ def predict_depth(image):
162
+ if depth_model is None:
163
+ return None
164
+ return depth_model.infer_image(image)
165
+
166
+ def calculate_max_points(image):
167
+ if image is None:
168
+ return 10000
169
+ h, w = image.shape[:2]
170
+ max_points = h * w * 3
171
+ return max(1000, min(max_points, 1000000))
172
+
173
+ def update_slider_on_image_upload(image):
174
+ max_points = calculate_max_points(image)
175
+ default_value = min(10000, max_points // 10)
176
+ return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
177
+ label=f"Number of 3D points (max: {max_points:,})")
178
+
179
+ def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
180
+ h, w = depth_map.shape
181
+ step = max(1, int(np.sqrt(h * w / max_points)))
182
+
183
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
184
+ x_cam = (x_coords - w / 2) / focal_length_x
185
+ y_cam = (y_coords - h / 2) / focal_length_y
186
+
187
+ depth_values = depth_map[::step, ::step]
188
+ x_3d = x_cam * depth_values
189
+ y_3d = y_cam * depth_values
190
+ z_3d = depth_values
191
+
192
+ points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
193
+ image_colors = image[::step, ::step, :]
194
+ colors = image_colors.reshape(-1, 3) / 255.0
195
+
196
+ pcd = o3d.geometry.PointCloud()
197
+ pcd.points = o3d.utility.Vector3dVector(points)
198
+ pcd.colors = o3d.utility.Vector3dVector(colors)
199
+
200
+ return pcd
201
+
202
+ def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
203
+ h, w = depth_map.shape
204
+ step = max(1, int(np.sqrt(h * w / max_points)))
205
+
206
+ y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
207
+ focal_length = 470.4
208
+ x_cam = (x_coords - w / 2) / focal_length
209
+ y_cam = (y_coords - h / 2) / focal_length
210
+
211
+ depth_values = depth_map[::step, ::step]
212
+ x_3d = x_cam * depth_values
213
+ y_3d = y_cam * depth_values
214
+ z_3d = depth_values
215
+
216
+ x_flat = x_3d.flatten()
217
+ y_flat = y_3d.flatten()
218
+ z_flat = z_3d.flatten()
219
+
220
+ image_colors = image[::step, ::step, :]
221
+ colors_flat = image_colors.reshape(-1, 3)
222
+
223
+ fig = go.Figure(data=[go.Scatter3d(
224
+ x=x_flat,
225
+ y=y_flat,
226
+ z=z_flat,
227
+ mode='markers',
228
+ marker=dict(
229
+ size=1.5,
230
+ color=colors_flat,
231
+ opacity=0.9
232
+ ),
233
+ hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
234
+ '<b>Depth:</b> %{z:.2f}<br>' +
235
+ '<extra></extra>'
236
+ )])
237
+
238
+ fig.update_layout(
239
+ title="3D Point Cloud Visualization (Camera Projection)",
240
+ scene=dict(
241
+ xaxis_title="X (meters)",
242
+ yaxis_title="Y (meters)",
243
+ zaxis_title="Z (meters)",
244
+ camera=dict(
245
+ eye=dict(x=2.0, y=2.0, z=2.0),
246
+ center=dict(x=0, y=0, z=0),
247
+ up=dict(x=0, y=0, z=1)
248
+ ),
249
+ aspectmode='data'
250
+ ),
251
+ width=700,
252
+ height=600
253
+ )
254
+
255
+ return fig
256
+
257
+ def on_depth_submit(image, num_points, focal_x, focal_y):
258
+ if image is None or depth_model is None:
259
+ return None, None, None, None, None
260
+
261
+ original_image = image.copy()
262
+ h, w = image.shape[:2]
263
+ depth = predict_depth(image[:, :, ::-1])
264
+
265
+ if depth is None:
266
+ return None, None, None, None, None
267
+
268
+ raw_depth = Image.fromarray(depth.astype('uint16'))
269
+ tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
270
+ raw_depth.save(tmp_raw_depth.name)
271
+
272
+ depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
273
+ depth = depth.astype(np.uint8)
274
+ cmap = matplotlib.colormaps.get_cmap('Spectral_r')
275
+ colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)
276
+
277
+ gray_depth = Image.fromarray(depth)
278
+ tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
279
+ gray_depth.save(tmp_gray_depth.name)
280
+
281
+ pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
282
+ tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
283
+ o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)
284
+
285
+ depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)
286
+
287
+ return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]
288
+
289
+ # --- Custom CSS for Unified Interface ---
290
+ css = """
291
+ /* Minimal dark theme styling */
292
+ .main-header {
293
+ text-align: center;
294
+ margin-bottom: 2rem;
295
+ padding: 2rem 0;
296
+ }
297
+
298
+ .main-header h1 {
299
+ font-size: 2.5rem;
300
+ margin-bottom: 0.5rem;
301
+ font-weight: 600;
302
+ }
303
+
304
+ .main-header p {
305
+ font-size: 1.1rem;
306
+ opacity: 0.8;
307
+ }
308
+
309
+ .section-title {
310
+ font-size: 1.2rem;
311
+ font-weight: 600;
312
+ margin-bottom: 15px;
313
+ padding-bottom: 8px;
314
+ border-bottom: 1px solid var(--border-color-primary);
315
+ }
316
+
317
+ .confidence-container {
318
+ margin: 15px 0;
319
+ padding: 15px;
320
+ border-radius: 8px;
321
+ background: var(--background-secondary);
322
+ border: 1px solid var(--border-color-primary);
323
+ }
324
+
325
+ .confidence-bar {
326
+ height: 20px;
327
+ border-radius: 4px;
328
+ margin: 6px 0;
329
+ background: var(--primary-500);
330
+ transition: width 0.3s ease;
331
+ }
332
+
333
+ /* Simple confidence bar colors */
334
+ .confidence-high {
335
+ background: var(--success-500);
336
+ }
337
+
338
+ .confidence-medium {
339
+ background: var(--warning-500);
340
+ }
341
+
342
+ .confidence-low {
343
+ background: var(--error-500);
344
+ }
345
+
346
+ /* Minimal spacing and layout */
347
+ .gradio-container {
348
+ max-width: 100%;
349
+ margin: 0;
350
+ padding: 20px;
351
+ width: 100%;
352
+ }
353
+
354
+ /* Clean image styling */
355
+ .gradio-image {
356
+ border-radius: 8px;
357
+ border: 1px solid var(--border-color-primary);
358
+ }
359
+
360
+ /* Simple button styling */
361
+ .gradio-button {
362
+ border-radius: 6px;
363
+ font-weight: 500;
364
+ }
365
+
366
+ /* Clean form elements */
367
+ .gradio-textbox, .gradio-number, .gradio-slider {
368
+ border-radius: 6px;
369
+ border: 1px solid var(--border-color-primary);
370
+ }
371
+
372
+ /* Tab styling */
373
+ .gradio-tabs {
374
+ border-radius: 8px;
375
+ overflow: hidden;
376
+ }
377
+
378
+ /* File upload styling */
379
+ .gradio-file {
380
+ border-radius: 6px;
381
+ border: 1px solid var(--border-color-primary);
382
+ }
383
+
384
+ /* Plot styling */
385
+ .gradio-plot {
386
+ border-radius: 8px;
387
+ border: 1px solid var(--border-color-primary);
388
+ }
389
+
390
+ /* Full width and height layout */
391
+ body, html {
392
+ margin: 0;
393
+ padding: 0;
394
+ width: 100%;
395
+ height: 100%;
396
+ }
397
+
398
+ #root {
399
+ width: 100%;
400
+ height: 100%;
401
+ }
402
+
403
+ /* Ensure Gradio uses full width */
404
+ .gradio-container {
405
+ min-height: 100vh;
406
+ }
407
+
408
+ /* Responsive adjustments */
409
+ @media (max-width: 768px) {
410
+ .main-header h1 {
411
+ font-size: 2rem;
412
+ }
413
+
414
+ .gradio-container {
415
+ padding: 10px;
416
+ }
417
+ }
418
+ """
419
+
420
+ # --- Create Unified Interface ---
421
+ with gr.Blocks(css=css, title="Medical AI Suite") as demo:
422
+ gr.HTML("""
423
+ <div class="main-header">
424
+ <h1>Medical AI Suite</h1>
425
+ <p>Advanced AI-powered medical image analysis and 3D visualization</p>
426
+ </div>
427
+ """)
428
+
429
+ with gr.Tabs() as tabs:
430
+ # Tab 1: Wound Classification
431
+ with gr.TabItem("Wound Classification", id=0):
432
+ gr.HTML("<div class='section-title'>Wound Classification with Grad-CAM Visualization</div>")
433
+
434
+ with gr.Row():
435
+ with gr.Column(scale=1):
436
+ gr.HTML("<div class='section-title'>Input Image</div>")
437
+ wound_input_image = gr.Image(
438
+ label="Upload wound image",
439
+ type="pil",
440
+ height=350,
441
+ container=True
442
+ )
443
+
444
+ with gr.Column(scale=1):
445
+ gr.HTML("<div class='section-title'>Analysis Results</div>")
446
+ wound_prediction_output = gr.Textbox(
447
+ label="Predicted Wound Type",
448
+ interactive=False,
449
+ container=True
450
+ )
451
+ wound_confidence_output = gr.Number(
452
+ label="Confidence Score",
453
+ interactive=False,
454
+ container=True
455
+ )
456
+ wound_confidence_bars = gr.HTML(
457
+ label="Confidence Scores by Class",
458
+ container=True
459
+ )
460
+
461
+ with gr.Row():
462
+ with gr.Column():
463
+ gr.HTML("<div class='section-title'>Model Focus Visualization</div>")
464
+ wound_cam_output = gr.Image(
465
+ label="Grad-CAM Heatmap - Shows which areas the model focused on",
466
+ height=350,
467
+ container=True
468
+ )
469
+
470
+ # Event handlers for wound classification
471
+ wound_input_image.change(
472
+ fn=enhanced_classify_and_explain,
473
+ inputs=[wound_input_image],
474
+ outputs=[wound_cam_output, wound_prediction_output, wound_confidence_output, wound_confidence_bars]
475
+ )
476
+
477
+ # Tab 2: Depth Estimation
478
+ with gr.TabItem("Depth Estimation & 3D Visualization", id=1):
479
+ gr.HTML("<div class='section-title'>Depth Estimation and 3D Point Cloud Generation</div>")
480
+
481
+ with gr.Row():
482
+ depth_input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
483
+ depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
484
+
485
+ with gr.Row():
486
+ depth_submit = gr.Button(value="Compute Depth", variant="primary")
487
+ depth_points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
488
+ label="Number of 3D points (upload image to update max)")
489
+
490
+ with gr.Row():
491
+ depth_focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
492
+ label="Focal Length X (pixels)")
493
+ depth_focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
494
+ label="Focal Length Y (pixels)")
495
+
496
+ with gr.Row():
497
+ depth_gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
498
+ depth_raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
499
+ depth_point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")
500
+
501
+ gr.Markdown("### 3D Point Cloud Visualization")
502
+ gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
503
+ depth_3d_plot = gr.Plot(label="3D Point Cloud")
504
+
505
+ # Event handlers for depth estimation
506
+ depth_input_image.change(
507
+ fn=update_slider_on_image_upload,
508
+ inputs=[depth_input_image],
509
+ outputs=[depth_points_slider]
510
+ )
511
+
512
+ depth_submit.click(
513
+ on_depth_submit,
514
+ inputs=[depth_input_image, depth_points_slider, depth_focal_length_x, depth_focal_length_y],
515
+ outputs=[depth_image_slider, depth_gray_depth_file, depth_raw_file, depth_point_cloud_file, depth_3d_plot]
516
+ )
517
+
518
+ # Cross-tab image sharing functionality
519
+ # When image is uploaded in wound classification, also update depth estimation
520
+ wound_input_image.change(
521
+ fn=lambda img: img,
522
+ inputs=[wound_input_image],
523
+ outputs=[depth_input_image]
524
+ )
525
+
526
+ # When image is uploaded in depth estimation, also update wound classification
527
+ depth_input_image.change(
528
+ fn=lambda img: img,
529
+ inputs=[depth_input_image],
530
+ outputs=[wound_input_image]
531
+ )
532
+
533
+ # --- Launch the unified interface ---
534
+ if __name__ == "__main__":
535
+ demo.queue().launch(
536
+ server_name="0.0.0.0",
537
+ server_port=7860,
538
+ share=True,
539
+ show_error=True
540
+ )
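The cross-tab sharing above is nothing more than identity `.change` handlers between the two image components. A minimal, standalone sketch of that pattern (component names and labels here are illustrative, not the ones used in this app):

```python
import gradio as gr

def mirror(img):
    # Identity handler: copy whatever was uploaded into the other tab's component.
    return img

with gr.Blocks() as demo:
    with gr.Tabs():
        with gr.TabItem("Classification"):
            cls_img = gr.Image(type="pil", label="Classification input")
        with gr.TabItem("Depth"):
            depth_img = gr.Image(type="pil", label="Depth input")

    # One direction shown here for clarity; the app above wires both directions
    # with the same identity function.
    cls_img.change(fn=mirror, inputs=[cls_img], outputs=[depth_img])

demo.queue().launch()
```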
metric_depth/README.md ADDED
@@ -0,0 +1,114 @@
1
+ # Depth Anything V2 for Metric Depth Estimation
2
+
3
+ ![teaser](./assets/compare_zoedepth.png)
4
+
5
+ Here we provide a simple codebase for fine-tuning our Depth Anything V2 pre-trained encoder for metric depth estimation. Built on this powerful encoder, we use a simple DPT head to regress depth. We fine-tune the pre-trained encoder on the synthetic Hypersim / Virtual KITTI datasets for indoor / outdoor metric depth estimation, respectively.
6
+
7
+
8
+ # Pre-trained Models
9
+
10
+ We provide **six metric depth models**: three scales each for indoor and outdoor scenes.
11
+
12
+ | Base Model | Params | Indoor (Hypersim) | Outdoor (Virtual KITTI 2) |
13
+ |:-|-:|:-:|:-:|
14
+ | Depth-Anything-V2-Small | 24.8M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Small/resolve/main/depth_anything_v2_metric_hypersim_vits.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Small/resolve/main/depth_anything_v2_metric_vkitti_vits.pth?download=true) |
15
+ | Depth-Anything-V2-Base | 97.5M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Base/resolve/main/depth_anything_v2_metric_hypersim_vitb.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Base/resolve/main/depth_anything_v2_metric_vkitti_vitb.pth?download=true) |
16
+ | Depth-Anything-V2-Large | 335.3M | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-Hypersim-Large/resolve/main/depth_anything_v2_metric_hypersim_vitl.pth?download=true) | [Download](https://huggingface.co/depth-anything/Depth-Anything-V2-Metric-VKITTI-Large/resolve/main/depth_anything_v2_metric_vkitti_vitl.pth?download=true) |
17
+
18
+ *We recommend first trying our larger models (if the computational cost is affordable) and the indoor version.*
19
+
20
+ ## Usage
21
+
22
+ ### Preparation
23
+
24
+ ```bash
25
+ git clone https://github.com/DepthAnything/Depth-Anything-V2
26
+ cd Depth-Anything-V2/metric_depth
27
+ pip install -r requirements.txt
28
+ ```
29
+
30
+ Download the checkpoints listed [here](#pre-trained-models) and put them under the `checkpoints` directory.
31
+
32
+ ### Use our models
33
+ ```python
34
+ import cv2
35
+ import torch
36
+
37
+ from depth_anything_v2.dpt import DepthAnythingV2
38
+
39
+ model_configs = {
40
+ 'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
41
+ 'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
42
+ 'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]}
43
+ }
44
+
45
+ encoder = 'vitl' # or 'vits', 'vitb'
46
+ dataset = 'hypersim' # 'hypersim' for indoor model, 'vkitti' for outdoor model
47
+ max_depth = 20 # 20 for indoor model, 80 for outdoor model
48
+
49
+ model = DepthAnythingV2(**{**model_configs[encoder], 'max_depth': max_depth})
50
+ model.load_state_dict(torch.load(f'checkpoints/depth_anything_v2_metric_{dataset}_{encoder}.pth', map_location='cpu'))
51
+ model.eval()
52
+
53
+ raw_img = cv2.imread('your/image/path')
54
+ depth = model.infer_image(raw_img) # HxW depth map in meters in numpy
55
+ ```
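
`infer_image` returns a float depth map in meters; a common follow-up, sketched below, is to store it as a lossless 16-bit PNG (the millimeter scaling is a convention chosen for this example, not something the repository prescribes):

```python
import cv2
import numpy as np

# Stand-in for the output of model.infer_image(raw_img): an HxW float32 map in meters.
depth = np.random.rand(480, 640).astype(np.float32) * 5.0

depth_mm = np.clip(depth * 1000.0, 0, 65535).astype(np.uint16)  # meters -> millimeters
cv2.imwrite('depth_16bit.png', depth_mm)                         # lossless 16-bit PNG

# Reload and recover meters
depth_back = cv2.imread('depth_16bit.png', cv2.IMREAD_UNCHANGED).astype(np.float32) / 1000.0
```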
56
+
57
+ ### Running the script on images
58
+
59
+ Here, we take the `vitl` encoder as an example. You can also use `vitb` or `vits` encoders.
60
+
61
+ ```bash
62
+ # indoor scenes
63
+ python run.py \
64
+ --encoder vitl \
65
+ --load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
66
+ --max-depth 20 \
67
+ --img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
68
+
69
+ # outdoor scenes
70
+ python run.py \
71
+ --encoder vitl \
72
+ --load-from checkpoints/depth_anything_v2_metric_vkitti_vitl.pth \
73
+ --max-depth 80 \
74
+ --img-path <path> --outdir <outdir> [--input-size <size>] [--save-numpy]
75
+ ```
76
+
77
+ ### Project 2D images to point clouds
78
+
79
+ ```bash
80
+ python depth_to_pointcloud.py \
81
+ --encoder vitl \
82
+ --load-from checkpoints/depth_anything_v2_metric_hypersim_vitl.pth \
83
+ --max-depth 20 \
84
+ --img-path <path> --outdir <outdir>
85
+ ```
86
+
87
+ ### Reproduce training
88
+
89
+ Please first prepare the [Hypersim](https://github.com/apple/ml-hypersim) and [Virtual KITTI 2](https://europe.naverlabs.com/research/computer-vision/proxy-virtual-worlds-vkitti-2/) datasets. Then:
90
+
91
+ ```bash
92
+ bash dist_train.sh
93
+ ```
94
+
95
+
96
+ ## Citation
97
+
98
+ If you find this project useful, please consider citing:
99
+
100
+ ```bibtex
101
+ @article{depth_anything_v2,
102
+ title={Depth Anything V2},
103
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Zhao, Zhen and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
104
+ journal={arXiv:2406.09414},
105
+ year={2024}
106
+ }
107
+
108
+ @inproceedings{depth_anything_v1,
109
+ title={Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data},
110
+ author={Yang, Lihe and Kang, Bingyi and Huang, Zilong and Xu, Xiaogang and Feng, Jiashi and Zhao, Hengshuang},
111
+ booktitle={CVPR},
112
+ year={2024}
113
+ }
114
+ ```
metric_depth/assets/compare_zoedepth.png ADDED

Git LFS Details

  • SHA256: 8044e39ef6cb4aaabea9a81333fa1ff2d3e07448e7f9f43f77f471aba72a12e0
  • Pointer size: 132 Bytes
  • Size of remote file: 9.19 MB
metric_depth/dataset/hypersim.py ADDED
@@ -0,0 +1,74 @@
1
+ import cv2
2
+ import h5py
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ from torchvision.transforms import Compose
7
+
8
+ from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
9
+
10
+
11
+ def hypersim_distance_to_depth(npyDistance):
12
+ intWidth, intHeight, fltFocal = 1024, 768, 886.81
13
+
14
+ npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(
15
+ 1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None]
16
+ npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5,
17
+ intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None]
18
+ npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32)
19
+ npyImageplane = np.concatenate(
20
+ [npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2)
21
+
22
+ npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal
23
+ return npyDepth
24
+
25
+
26
+ class Hypersim(Dataset):
27
+ def __init__(self, filelist_path, mode, size=(518, 518)):
28
+
29
+ self.mode = mode
30
+ self.size = size
31
+
32
+ with open(filelist_path, 'r') as f:
33
+ self.filelist = f.read().splitlines()
34
+
35
+ net_w, net_h = size
36
+ self.transform = Compose([
37
+ Resize(
38
+ width=net_w,
39
+ height=net_h,
40
+ resize_target=True if mode == 'train' else False,
41
+ keep_aspect_ratio=True,
42
+ ensure_multiple_of=14,
43
+ resize_method='lower_bound',
44
+ image_interpolation_method=cv2.INTER_CUBIC,
45
+ ),
46
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
47
+ PrepareForNet(),
48
+ ] + ([Crop(size[0])] if self.mode == 'train' else []))
49
+
50
+ def __getitem__(self, item):
51
+ img_path = self.filelist[item].split(' ')[0]
52
+ depth_path = self.filelist[item].split(' ')[1]
53
+
54
+ image = cv2.imread(img_path)
55
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
56
+
57
+ depth_fd = h5py.File(depth_path, "r")
58
+ distance_meters = np.array(depth_fd['dataset'])
59
+ depth = hypersim_distance_to_depth(distance_meters)
60
+
61
+ sample = self.transform({'image': image, 'depth': depth})
62
+
63
+ sample['image'] = torch.from_numpy(sample['image'])
64
+ sample['depth'] = torch.from_numpy(sample['depth'])
65
+
66
+ sample['valid_mask'] = (torch.isnan(sample['depth']) == 0)
67
+ sample['depth'][sample['valid_mask'] == 0] = 0
68
+
69
+ sample['image_path'] = self.filelist[item].split(' ')[0]
70
+
71
+ return sample
72
+
73
+ def __len__(self):
74
+ return len(self.filelist)
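
A minimal sketch of consuming this dataset with a `DataLoader` (the split path is the one shipped in this upload; each line of it is expected to hold `<image_path> <hdf5_depth_path>`):

```python
from torch.utils.data import DataLoader
from dataset.hypersim import Hypersim

dataset = Hypersim('dataset/splits/hypersim/val.txt', mode='val', size=(518, 518))
loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=2)

for sample in loader:
    image = sample['image']        # resized + normalized tensor, shape (1, 3, h, w)
    depth = sample['depth']        # planar depth in meters; kept at native resolution in 'val' mode
    valid = sample['valid_mask']   # bool mask of the originally finite depth values
    break
```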
metric_depth/dataset/kitti.py ADDED
@@ -0,0 +1,57 @@
1
+ import cv2
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from torchvision.transforms import Compose
5
+
6
+ from dataset.transform import Resize, NormalizeImage, PrepareForNet
7
+
8
+
9
+ class KITTI(Dataset):
10
+ def __init__(self, filelist_path, mode, size=(518, 518)):
11
+ if mode != 'val':
12
+ raise NotImplementedError
13
+
14
+ self.mode = mode
15
+ self.size = size
16
+
17
+ with open(filelist_path, 'r') as f:
18
+ self.filelist = f.read().splitlines()
19
+
20
+ net_w, net_h = size
21
+ self.transform = Compose([
22
+ Resize(
23
+ width=net_w,
24
+ height=net_h,
25
+ resize_target=True if mode == 'train' else False,
26
+ keep_aspect_ratio=True,
27
+ ensure_multiple_of=14,
28
+ resize_method='lower_bound',
29
+ image_interpolation_method=cv2.INTER_CUBIC,
30
+ ),
31
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
32
+ PrepareForNet(),
33
+ ])
34
+
35
+ def __getitem__(self, item):
36
+ img_path = self.filelist[item].split(' ')[0]
37
+ depth_path = self.filelist[item].split(' ')[1]
38
+
39
+ image = cv2.imread(img_path)
40
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
41
+
42
+ depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED).astype('float32')
43
+
44
+ sample = self.transform({'image': image, 'depth': depth})
45
+
46
+ sample['image'] = torch.from_numpy(sample['image'])
47
+ sample['depth'] = torch.from_numpy(sample['depth'])
48
+ sample['depth'] = sample['depth'] / 256.0 # convert to meters
49
+
50
+ sample['valid_mask'] = sample['depth'] > 0
51
+
52
+ sample['image_path'] = self.filelist[item].split(' ')[0]
53
+
54
+ return sample
55
+
56
+ def __len__(self):
57
+ return len(self.filelist)
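
The `/ 256.0` above follows the KITTI depth-benchmark convention; reading a ground-truth PNG directly looks like this (the path is a placeholder):

```python
import cv2
import numpy as np

# KITTI ground-truth depth maps are uint16 PNGs: 0 marks missing pixels,
# and valid pixels store depth * 256, so dividing by 256 recovers meters.
raw = cv2.imread('path/to/groundtruth_depth.png', cv2.IMREAD_UNCHANGED)
depth_m = raw.astype(np.float32) / 256.0
valid = raw > 0
print(depth_m[valid].min(), depth_m[valid].max())
```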
metric_depth/dataset/splits/hypersim/val.txt ADDED
The diff for this file is too large to render. See raw diff
 
metric_depth/dataset/splits/kitti/val.txt ADDED
The diff for this file is too large to render. See raw diff
 
metric_depth/dataset/splits/vkitti2/train.txt ADDED
The diff for this file is too large to render. See raw diff
 
metric_depth/dataset/transform.py ADDED
@@ -0,0 +1,277 @@
1
+ import cv2
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA):
9
+ """Resize the sample to ensure the given size. Keeps aspect ratio.
10
+
11
+ Args:
12
+ sample (dict): sample
13
+ size (tuple): image size
14
+
15
+ Returns:
16
+ tuple: new size
17
+ """
18
+ shape = list(sample["disparity"].shape)
19
+
20
+ if shape[0] >= size[0] and shape[1] >= size[1]:
21
+ return sample
22
+
23
+ scale = [0, 0]
24
+ scale[0] = size[0] / shape[0]
25
+ scale[1] = size[1] / shape[1]
26
+
27
+ scale = max(scale)
28
+
29
+ shape[0] = math.ceil(scale * shape[0])
30
+ shape[1] = math.ceil(scale * shape[1])
31
+
32
+ # resize
33
+ sample["image"] = cv2.resize(
34
+ sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method
35
+ )
36
+
37
+ sample["disparity"] = cv2.resize(
38
+ sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST
39
+ )
40
+ sample["mask"] = cv2.resize(
41
+ sample["mask"].astype(np.float32),
42
+ tuple(shape[::-1]),
43
+ interpolation=cv2.INTER_NEAREST,
44
+ )
45
+ sample["mask"] = sample["mask"].astype(bool)
46
+
47
+ return tuple(shape)
48
+
49
+
50
+ class Resize(object):
51
+ """Resize sample to given size (width, height).
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ width,
57
+ height,
58
+ resize_target=True,
59
+ keep_aspect_ratio=False,
60
+ ensure_multiple_of=1,
61
+ resize_method="lower_bound",
62
+ image_interpolation_method=cv2.INTER_AREA,
63
+ ):
64
+ """Init.
65
+
66
+ Args:
67
+ width (int): desired output width
68
+ height (int): desired output height
69
+ resize_target (bool, optional):
70
+ True: Resize the full sample (image, mask, target).
71
+ False: Resize image only.
72
+ Defaults to True.
73
+ keep_aspect_ratio (bool, optional):
74
+ True: Keep the aspect ratio of the input sample.
75
+ Output sample might not have the given width and height, and
76
+ resize behaviour depends on the parameter 'resize_method'.
77
+ Defaults to False.
78
+ ensure_multiple_of (int, optional):
79
+ Output width and height is constrained to be multiple of this parameter.
80
+ Defaults to 1.
81
+ resize_method (str, optional):
82
+ "lower_bound": Output will be at least as large as the given size.
83
+ "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
84
+ "minimal": Scale as little as possible. (Output size might be smaller than given size.)
85
+ Defaults to "lower_bound".
86
+ """
87
+ self.__width = width
88
+ self.__height = height
89
+
90
+ self.__resize_target = resize_target
91
+ self.__keep_aspect_ratio = keep_aspect_ratio
92
+ self.__multiple_of = ensure_multiple_of
93
+ self.__resize_method = resize_method
94
+ self.__image_interpolation_method = image_interpolation_method
95
+
96
+ def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
97
+ y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
98
+
99
+ if max_val is not None and y > max_val:
100
+ y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
101
+
102
+ if y < min_val:
103
+ y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
104
+
105
+ return y
106
+
107
+ def get_size(self, width, height):
108
+ # determine new height and width
109
+ scale_height = self.__height / height
110
+ scale_width = self.__width / width
111
+
112
+ if self.__keep_aspect_ratio:
113
+ if self.__resize_method == "lower_bound":
114
+ # scale such that output size is lower bound
115
+ if scale_width > scale_height:
116
+ # fit width
117
+ scale_height = scale_width
118
+ else:
119
+ # fit height
120
+ scale_width = scale_height
121
+ elif self.__resize_method == "upper_bound":
122
+ # scale such that output size is upper bound
123
+ if scale_width < scale_height:
124
+ # fit width
125
+ scale_height = scale_width
126
+ else:
127
+ # fit height
128
+ scale_width = scale_height
129
+ elif self.__resize_method == "minimal":
130
+ # scale as little as possible
131
+ if abs(1 - scale_width) < abs(1 - scale_height):
132
+ # fit width
133
+ scale_height = scale_width
134
+ else:
135
+ # fit height
136
+ scale_width = scale_height
137
+ else:
138
+ raise ValueError(
139
+ f"resize_method {self.__resize_method} not implemented"
140
+ )
141
+
142
+ if self.__resize_method == "lower_bound":
143
+ new_height = self.constrain_to_multiple_of(
144
+ scale_height * height, min_val=self.__height
145
+ )
146
+ new_width = self.constrain_to_multiple_of(
147
+ scale_width * width, min_val=self.__width
148
+ )
149
+ elif self.__resize_method == "upper_bound":
150
+ new_height = self.constrain_to_multiple_of(
151
+ scale_height * height, max_val=self.__height
152
+ )
153
+ new_width = self.constrain_to_multiple_of(
154
+ scale_width * width, max_val=self.__width
155
+ )
156
+ elif self.__resize_method == "minimal":
157
+ new_height = self.constrain_to_multiple_of(scale_height * height)
158
+ new_width = self.constrain_to_multiple_of(scale_width * width)
159
+ else:
160
+ raise ValueError(f"resize_method {self.__resize_method} not implemented")
161
+
162
+ return (new_width, new_height)
163
+
164
+ def __call__(self, sample):
165
+ width, height = self.get_size(
166
+ sample["image"].shape[1], sample["image"].shape[0]
167
+ )
168
+
169
+ # resize sample
170
+ sample["image"] = cv2.resize(
171
+ sample["image"],
172
+ (width, height),
173
+ interpolation=self.__image_interpolation_method,
174
+ )
175
+
176
+ if self.__resize_target:
177
+ if "disparity" in sample:
178
+ sample["disparity"] = cv2.resize(
179
+ sample["disparity"],
180
+ (width, height),
181
+ interpolation=cv2.INTER_NEAREST,
182
+ )
183
+
184
+ if "depth" in sample:
185
+ sample["depth"] = cv2.resize(
186
+ sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST
187
+ )
188
+
189
+ if "semseg_mask" in sample:
190
+ # sample["semseg_mask"] = cv2.resize(
191
+ # sample["semseg_mask"], (width, height), interpolation=cv2.INTER_NEAREST
192
+ # )
193
+ sample["semseg_mask"] = F.interpolate(torch.from_numpy(sample["semseg_mask"]).float()[None, None, ...], (height, width), mode='nearest').numpy()[0, 0]
194
+
195
+ if "mask" in sample:
196
+ sample["mask"] = cv2.resize(
197
+ sample["mask"].astype(np.float32),
198
+ (width, height),
199
+ interpolation=cv2.INTER_NEAREST,
200
+ )
201
+ # sample["mask"] = sample["mask"].astype(bool)
202
+
203
+ # print(sample['image'].shape, sample['depth'].shape)
204
+ return sample
205
+
206
+
207
+ class NormalizeImage(object):
208
+ """Normalize the image by the given mean and std.
209
+ """
210
+
211
+ def __init__(self, mean, std):
212
+ self.__mean = mean
213
+ self.__std = std
214
+
215
+ def __call__(self, sample):
216
+ sample["image"] = (sample["image"] - self.__mean) / self.__std
217
+
218
+ return sample
219
+
220
+
221
+ class PrepareForNet(object):
222
+ """Prepare sample for usage as network input.
223
+ """
224
+
225
+ def __init__(self):
226
+ pass
227
+
228
+ def __call__(self, sample):
229
+ image = np.transpose(sample["image"], (2, 0, 1))
230
+ sample["image"] = np.ascontiguousarray(image).astype(np.float32)
231
+
232
+ if "mask" in sample:
233
+ sample["mask"] = sample["mask"].astype(np.float32)
234
+ sample["mask"] = np.ascontiguousarray(sample["mask"])
235
+
236
+ if "depth" in sample:
237
+ depth = sample["depth"].astype(np.float32)
238
+ sample["depth"] = np.ascontiguousarray(depth)
239
+
240
+ if "semseg_mask" in sample:
241
+ sample["semseg_mask"] = sample["semseg_mask"].astype(np.float32)
242
+ sample["semseg_mask"] = np.ascontiguousarray(sample["semseg_mask"])
243
+
244
+ return sample
245
+
246
+
247
+ class Crop(object):
248
+ """Crop sample for batch-wise training. Image is of shape CxHxW
249
+ """
250
+
251
+ def __init__(self, size):
252
+ if isinstance(size, int):
253
+ self.size = (size, size)
254
+ else:
255
+ self.size = size
256
+
257
+ def __call__(self, sample):
258
+ h, w = sample['image'].shape[-2:]
259
+ assert h >= self.size[0] and w >= self.size[1], 'Wrong size'
260
+
261
+ h_start = np.random.randint(0, h - self.size[0] + 1)
262
+ w_start = np.random.randint(0, w - self.size[1] + 1)
263
+ h_end = h_start + self.size[0]
264
+ w_end = w_start + self.size[1]
265
+
266
+ sample['image'] = sample['image'][:, h_start: h_end, w_start: w_end]
267
+
268
+ if "depth" in sample:
269
+ sample["depth"] = sample["depth"][h_start: h_end, w_start: w_end]
270
+
271
+ if "mask" in sample:
272
+ sample["mask"] = sample["mask"][h_start: h_end, w_start: w_end]
273
+
274
+ if "semseg_mask" in sample:
275
+ sample["semseg_mask"] = sample["semseg_mask"][h_start: h_end, w_start: w_end]
276
+
277
+ return sample
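
To make the `lower_bound` resize with `ensure_multiple_of=14` concrete, here is a small check (the 1920x1080 input is just an example):

```python
from dataset.transform import Resize

# 'lower_bound' keeps the aspect ratio, scales so both sides are at least 518,
# and rounds each side to a multiple of 14 (the ViT patch size).
resize = Resize(width=518, height=518, resize_target=False,
                keep_aspect_ratio=True, ensure_multiple_of=14,
                resize_method='lower_bound')
print(resize.get_size(1920, 1080))   # -> (924, 518), i.e. (width, height), both multiples of 14
```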
metric_depth/dataset/vkitti2.py ADDED
@@ -0,0 +1,54 @@
1
+ import cv2
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from torchvision.transforms import Compose
5
+
6
+ from dataset.transform import Resize, NormalizeImage, PrepareForNet, Crop
7
+
8
+
9
+ class VKITTI2(Dataset):
10
+ def __init__(self, filelist_path, mode, size=(518, 518)):
11
+
12
+ self.mode = mode
13
+ self.size = size
14
+
15
+ with open(filelist_path, 'r') as f:
16
+ self.filelist = f.read().splitlines()
17
+
18
+ net_w, net_h = size
19
+ self.transform = Compose([
20
+ Resize(
21
+ width=net_w,
22
+ height=net_h,
23
+ resize_target=True if mode == 'train' else False,
24
+ keep_aspect_ratio=True,
25
+ ensure_multiple_of=14,
26
+ resize_method='lower_bound',
27
+ image_interpolation_method=cv2.INTER_CUBIC,
28
+ ),
29
+ NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
30
+ PrepareForNet(),
31
+ ] + ([Crop(size[0])] if self.mode == 'train' else []))
32
+
33
+ def __getitem__(self, item):
34
+ img_path = self.filelist[item].split(' ')[0]
35
+ depth_path = self.filelist[item].split(' ')[1]
36
+
37
+ image = cv2.imread(img_path)
38
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
39
+
40
+ depth = cv2.imread(depth_path, cv2.IMREAD_ANYCOLOR | cv2.IMREAD_ANYDEPTH) / 100.0 # cm to m
41
+
42
+ sample = self.transform({'image': image, 'depth': depth})
43
+
44
+ sample['image'] = torch.from_numpy(sample['image'])
45
+ sample['depth'] = torch.from_numpy(sample['depth'])
46
+
47
+ sample['valid_mask'] = (sample['depth'] <= 80)
48
+
49
+ sample['image_path'] = self.filelist[item].split(' ')[0]
50
+
51
+ return sample
52
+
53
+ def __len__(self):
54
+ return len(self.filelist)
metric_depth/depth_anything_v2/dinov2.py ADDED
@@ -0,0 +1,415 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
27
+ if not depth_first and include_root:
28
+ fn(module=module, name=name)
29
+ for child_name, child_module in module.named_children():
30
+ child_name = ".".join((name, child_name)) if name else child_name
31
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
32
+ if depth_first and include_root:
33
+ fn(module=module, name=name)
34
+ return module
35
+
36
+
37
+ class BlockChunk(nn.ModuleList):
38
+ def forward(self, x):
39
+ for b in self:
40
+ x = b(x)
41
+ return x
42
+
43
+
44
+ class DinoVisionTransformer(nn.Module):
45
+ def __init__(
46
+ self,
47
+ img_size=224,
48
+ patch_size=16,
49
+ in_chans=3,
50
+ embed_dim=768,
51
+ depth=12,
52
+ num_heads=12,
53
+ mlp_ratio=4.0,
54
+ qkv_bias=True,
55
+ ffn_bias=True,
56
+ proj_bias=True,
57
+ drop_path_rate=0.0,
58
+ drop_path_uniform=False,
59
+ init_values=None, # for layerscale: None or 0 => no layerscale
60
+ embed_layer=PatchEmbed,
61
+ act_layer=nn.GELU,
62
+ block_fn=Block,
63
+ ffn_layer="mlp",
64
+ block_chunks=1,
65
+ num_register_tokens=0,
66
+ interpolate_antialias=False,
67
+ interpolate_offset=0.1,
68
+ ):
69
+ """
70
+ Args:
71
+ img_size (int, tuple): input image size
72
+ patch_size (int, tuple): patch size
73
+ in_chans (int): number of input channels
74
+ embed_dim (int): embedding dimension
75
+ depth (int): depth of transformer
76
+ num_heads (int): number of attention heads
77
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
78
+ qkv_bias (bool): enable bias for qkv if True
79
+ proj_bias (bool): enable bias for proj in attn if True
80
+ ffn_bias (bool): enable bias for ffn if True
81
+ drop_path_rate (float): stochastic depth rate
82
+ drop_path_uniform (bool): apply uniform drop rate across blocks
83
+ weight_init (str): weight init scheme
84
+ init_values (float): layer-scale init values
85
+ embed_layer (nn.Module): patch embedding layer
86
+ act_layer (nn.Module): MLP activation layer
87
+ block_fn (nn.Module): transformer block class
88
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
89
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
90
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
91
+ interpolate_antialias: (bool) flag to apply anti-aliasing when interpolating positional embeddings
92
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
93
+ """
94
+ super().__init__()
95
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
96
+
97
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
98
+ self.num_tokens = 1
99
+ self.n_blocks = depth
100
+ self.num_heads = num_heads
101
+ self.patch_size = patch_size
102
+ self.num_register_tokens = num_register_tokens
103
+ self.interpolate_antialias = interpolate_antialias
104
+ self.interpolate_offset = interpolate_offset
105
+
106
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
107
+ num_patches = self.patch_embed.num_patches
108
+
109
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
110
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
111
+ assert num_register_tokens >= 0
112
+ self.register_tokens = (
113
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
114
+ )
115
+
116
+ if drop_path_uniform is True:
117
+ dpr = [drop_path_rate] * depth
118
+ else:
119
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
120
+
121
+ if ffn_layer == "mlp":
122
+ logger.info("using MLP layer as FFN")
123
+ ffn_layer = Mlp
124
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
125
+ logger.info("using SwiGLU layer as FFN")
126
+ ffn_layer = SwiGLUFFNFused
127
+ elif ffn_layer == "identity":
128
+ logger.info("using Identity layer as FFN")
129
+
130
+ def f(*args, **kwargs):
131
+ return nn.Identity()
132
+
133
+ ffn_layer = f
134
+ else:
135
+ raise NotImplementedError
136
+
137
+ blocks_list = [
138
+ block_fn(
139
+ dim=embed_dim,
140
+ num_heads=num_heads,
141
+ mlp_ratio=mlp_ratio,
142
+ qkv_bias=qkv_bias,
143
+ proj_bias=proj_bias,
144
+ ffn_bias=ffn_bias,
145
+ drop_path=dpr[i],
146
+ norm_layer=norm_layer,
147
+ act_layer=act_layer,
148
+ ffn_layer=ffn_layer,
149
+ init_values=init_values,
150
+ )
151
+ for i in range(depth)
152
+ ]
153
+ if block_chunks > 0:
154
+ self.chunked_blocks = True
155
+ chunked_blocks = []
156
+ chunksize = depth // block_chunks
157
+ for i in range(0, depth, chunksize):
158
+ # this is to keep the block index consistent if we chunk the block list
159
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
160
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
161
+ else:
162
+ self.chunked_blocks = False
163
+ self.blocks = nn.ModuleList(blocks_list)
164
+
165
+ self.norm = norm_layer(embed_dim)
166
+ self.head = nn.Identity()
167
+
168
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
169
+
170
+ self.init_weights()
171
+
172
+ def init_weights(self):
173
+ trunc_normal_(self.pos_embed, std=0.02)
174
+ nn.init.normal_(self.cls_token, std=1e-6)
175
+ if self.register_tokens is not None:
176
+ nn.init.normal_(self.register_tokens, std=1e-6)
177
+ named_apply(init_weights_vit_timm, self)
178
+
179
+ def interpolate_pos_encoding(self, x, w, h):
180
+ previous_dtype = x.dtype
181
+ npatch = x.shape[1] - 1
182
+ N = self.pos_embed.shape[1] - 1
183
+ if npatch == N and w == h:
184
+ return self.pos_embed
185
+ pos_embed = self.pos_embed.float()
186
+ class_pos_embed = pos_embed[:, 0]
187
+ patch_pos_embed = pos_embed[:, 1:]
188
+ dim = x.shape[-1]
189
+ w0 = w // self.patch_size
190
+ h0 = h // self.patch_size
191
+ # we add a small number to avoid floating point error in the interpolation
192
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
193
+ # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
194
+ w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
195
+ # w0, h0 = w0 + 0.1, h0 + 0.1
196
+
197
+ sqrt_N = math.sqrt(N)
198
+ sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
199
+ patch_pos_embed = nn.functional.interpolate(
200
+ patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
201
+ scale_factor=(sx, sy),
202
+ # (int(w0), int(h0)), # to solve the upsampling shape issue
203
+ mode="bicubic",
204
+ antialias=self.interpolate_antialias
205
+ )
206
+
207
+ assert int(w0) == patch_pos_embed.shape[-2]
208
+ assert int(h0) == patch_pos_embed.shape[-1]
209
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
210
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
211
+
212
+ def prepare_tokens_with_masks(self, x, masks=None):
213
+ B, nc, w, h = x.shape
214
+ x = self.patch_embed(x)
215
+ if masks is not None:
216
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
217
+
218
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
219
+ x = x + self.interpolate_pos_encoding(x, w, h)
220
+
221
+ if self.register_tokens is not None:
222
+ x = torch.cat(
223
+ (
224
+ x[:, :1],
225
+ self.register_tokens.expand(x.shape[0], -1, -1),
226
+ x[:, 1:],
227
+ ),
228
+ dim=1,
229
+ )
230
+
231
+ return x
232
+
233
+ def forward_features_list(self, x_list, masks_list):
234
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
235
+ for blk in self.blocks:
236
+ x = blk(x)
237
+
238
+ all_x = x
239
+ output = []
240
+ for x, masks in zip(all_x, masks_list):
241
+ x_norm = self.norm(x)
242
+ output.append(
243
+ {
244
+ "x_norm_clstoken": x_norm[:, 0],
245
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
246
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
247
+ "x_prenorm": x,
248
+ "masks": masks,
249
+ }
250
+ )
251
+ return output
252
+
253
+ def forward_features(self, x, masks=None):
254
+ if isinstance(x, list):
255
+ return self.forward_features_list(x, masks)
256
+
257
+ x = self.prepare_tokens_with_masks(x, masks)
258
+
259
+ for blk in self.blocks:
260
+ x = blk(x)
261
+
262
+ x_norm = self.norm(x)
263
+ return {
264
+ "x_norm_clstoken": x_norm[:, 0],
265
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
266
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
267
+ "x_prenorm": x,
268
+ "masks": masks,
269
+ }
270
+
271
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
272
+ x = self.prepare_tokens_with_masks(x)
273
+ # If n is an int, take the n last blocks. If it's a list, take them
274
+ output, total_block_len = [], len(self.blocks)
275
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
276
+ for i, blk in enumerate(self.blocks):
277
+ x = blk(x)
278
+ if i in blocks_to_take:
279
+ output.append(x)
280
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
281
+ return output
282
+
283
+ def _get_intermediate_layers_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
286
+ # If n is an int, take the n last blocks. If it's a list, take them
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for block_chunk in self.blocks:
289
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
290
+ x = blk(x)
291
+ if i in blocks_to_take:
292
+ output.append(x)
293
+ i += 1
294
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
295
+ return output
296
+
297
+ def get_intermediate_layers(
298
+ self,
299
+ x: torch.Tensor,
300
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
301
+ reshape: bool = False,
302
+ return_class_token: bool = False,
303
+ norm=True
304
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
305
+ if self.chunked_blocks:
306
+ outputs = self._get_intermediate_layers_chunked(x, n)
307
+ else:
308
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
309
+ if norm:
310
+ outputs = [self.norm(out) for out in outputs]
311
+ class_tokens = [out[:, 0] for out in outputs]
312
+ outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
313
+ if reshape:
314
+ B, _, w, h = x.shape
315
+ outputs = [
316
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
317
+ for out in outputs
318
+ ]
319
+ if return_class_token:
320
+ return tuple(zip(outputs, class_tokens))
321
+ return tuple(outputs)
322
+
323
+ def forward(self, *args, is_training=False, **kwargs):
324
+ ret = self.forward_features(*args, **kwargs)
325
+ if is_training:
326
+ return ret
327
+ else:
328
+ return self.head(ret["x_norm_clstoken"])
329
+
330
+
331
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
332
+ """ViT weight initialization, original timm impl (for reproducibility)"""
333
+ if isinstance(module, nn.Linear):
334
+ trunc_normal_(module.weight, std=0.02)
335
+ if module.bias is not None:
336
+ nn.init.zeros_(module.bias)
337
+
338
+
339
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
340
+ model = DinoVisionTransformer(
341
+ patch_size=patch_size,
342
+ embed_dim=384,
343
+ depth=12,
344
+ num_heads=6,
345
+ mlp_ratio=4,
346
+ block_fn=partial(Block, attn_class=MemEffAttention),
347
+ num_register_tokens=num_register_tokens,
348
+ **kwargs,
349
+ )
350
+ return model
351
+
352
+
353
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
354
+ model = DinoVisionTransformer(
355
+ patch_size=patch_size,
356
+ embed_dim=768,
357
+ depth=12,
358
+ num_heads=12,
359
+ mlp_ratio=4,
360
+ block_fn=partial(Block, attn_class=MemEffAttention),
361
+ num_register_tokens=num_register_tokens,
362
+ **kwargs,
363
+ )
364
+ return model
365
+
366
+
367
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
368
+ model = DinoVisionTransformer(
369
+ patch_size=patch_size,
370
+ embed_dim=1024,
371
+ depth=24,
372
+ num_heads=16,
373
+ mlp_ratio=4,
374
+ block_fn=partial(Block, attn_class=MemEffAttention),
375
+ num_register_tokens=num_register_tokens,
376
+ **kwargs,
377
+ )
378
+ return model
379
+
380
+
381
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
382
+ """
383
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
384
+ """
385
+ model = DinoVisionTransformer(
386
+ patch_size=patch_size,
387
+ embed_dim=1536,
388
+ depth=40,
389
+ num_heads=24,
390
+ mlp_ratio=4,
391
+ block_fn=partial(Block, attn_class=MemEffAttention),
392
+ num_register_tokens=num_register_tokens,
393
+ **kwargs,
394
+ )
395
+ return model
396
+
397
+
398
+ def DINOv2(model_name):
399
+ model_zoo = {
400
+ "vits": vit_small,
401
+ "vitb": vit_base,
402
+ "vitl": vit_large,
403
+ "vitg": vit_giant2
404
+ }
405
+
406
+ return model_zoo[model_name](
407
+ img_size=518,
408
+ patch_size=14,
409
+ init_values=1.0,
410
+ ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
411
+ block_chunks=0,
412
+ num_register_tokens=0,
413
+ interpolate_antialias=False,
414
+ interpolate_offset=0.1
415
+ )
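
`get_intermediate_layers` with `reshape=True` is what a DPT-style decoder consumes; a quick sketch with randomly initialised weights (the layer indices below are chosen for illustration, not taken from this repository's configuration):

```python
import torch
from depth_anything_v2.dinov2 import DINOv2

encoder = DINOv2('vits')                 # ViT-S/14, embed_dim=384, built for 518x518 inputs
x = torch.randn(1, 3, 518, 518)
with torch.no_grad():
    features = encoder.get_intermediate_layers(
        x, n=[2, 5, 8, 11], reshape=True, return_class_token=True
    )
for fmap, cls_token in features:
    # 518 / 14 = 37 patches per side
    print(fmap.shape, cls_token.shape)   # torch.Size([1, 384, 37, 37]) torch.Size([1, 384])
```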
metric_depth/depth_anything_v2/dinov2_layers/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
metric_depth/depth_anything_v2/dinov2_layers/attention.py ADDED
@@ -0,0 +1,83 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+
13
+ from torch import Tensor
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ try:
21
+ from xformers.ops import memory_efficient_attention, unbind, fmha
22
+
23
+ XFORMERS_AVAILABLE = True
24
+ except ImportError:
25
+ logger.warning("xFormers not available")
26
+ XFORMERS_AVAILABLE = False
27
+
28
+
29
+ class Attention(nn.Module):
30
+ def __init__(
31
+ self,
32
+ dim: int,
33
+ num_heads: int = 8,
34
+ qkv_bias: bool = False,
35
+ proj_bias: bool = True,
36
+ attn_drop: float = 0.0,
37
+ proj_drop: float = 0.0,
38
+ ) -> None:
39
+ super().__init__()
40
+ self.num_heads = num_heads
41
+ head_dim = dim // num_heads
42
+ self.scale = head_dim**-0.5
43
+
44
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
45
+ self.attn_drop = nn.Dropout(attn_drop)
46
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
47
+ self.proj_drop = nn.Dropout(proj_drop)
48
+
49
+ def forward(self, x: Tensor) -> Tensor:
50
+ B, N, C = x.shape
51
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
52
+
53
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
54
+ attn = q @ k.transpose(-2, -1)
55
+
56
+ attn = attn.softmax(dim=-1)
57
+ attn = self.attn_drop(attn)
58
+
59
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
60
+ x = self.proj(x)
61
+ x = self.proj_drop(x)
62
+ return x
63
+
64
+
65
+ class MemEffAttention(Attention):
66
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
67
+ if not XFORMERS_AVAILABLE:
68
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
69
+ return super().forward(x)
70
+
71
+ B, N, C = x.shape
72
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
73
+
74
+ q, k, v = unbind(qkv, 2)
75
+
76
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
77
+ x = x.reshape([B, N, C])
78
+
79
+ x = self.proj(x)
80
+ x = self.proj_drop(x)
81
+ return x
82
+
83
+
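
A quick shape sanity check for the vanilla path; `MemEffAttention` falls back to exactly this computation when xFormers is not installed (and `attn_bias` must then be `None`):

```python
import torch
from depth_anything_v2.dinov2_layers.attention import Attention

attn = Attention(dim=64, num_heads=4, qkv_bias=True).eval()
x = torch.randn(2, 10, 64)                # (batch, tokens, channels)
with torch.no_grad():
    y = attn(x)
print(y.shape)                            # torch.Size([2, 10, 64]) -- shape is preserved
```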
metric_depth/depth_anything_v2/dinov2_layers/block.py ADDED
@@ -0,0 +1,252 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ import logging
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+
14
+ import torch
15
+ from torch import nn, Tensor
16
+
17
+ from .attention import Attention, MemEffAttention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
+ try:
27
+ from xformers.ops import fmha
28
+ from xformers.ops import scaled_index_add, index_select_cat
29
+
30
+ XFORMERS_AVAILABLE = True
31
+ except ImportError:
32
+ logger.warning("xFormers not available")
33
+ XFORMERS_AVAILABLE = False
34
+
35
+
36
+ class Block(nn.Module):
37
+ def __init__(
38
+ self,
39
+ dim: int,
40
+ num_heads: int,
41
+ mlp_ratio: float = 4.0,
42
+ qkv_bias: bool = False,
43
+ proj_bias: bool = True,
44
+ ffn_bias: bool = True,
45
+ drop: float = 0.0,
46
+ attn_drop: float = 0.0,
47
+ init_values=None,
48
+ drop_path: float = 0.0,
49
+ act_layer: Callable[..., nn.Module] = nn.GELU,
50
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
51
+ attn_class: Callable[..., nn.Module] = Attention,
52
+ ffn_layer: Callable[..., nn.Module] = Mlp,
53
+ ) -> None:
54
+ super().__init__()
55
+ # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
56
+ self.norm1 = norm_layer(dim)
57
+ self.attn = attn_class(
58
+ dim,
59
+ num_heads=num_heads,
60
+ qkv_bias=qkv_bias,
61
+ proj_bias=proj_bias,
62
+ attn_drop=attn_drop,
63
+ proj_drop=drop,
64
+ )
65
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
66
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
67
+
68
+ self.norm2 = norm_layer(dim)
69
+ mlp_hidden_dim = int(dim * mlp_ratio)
70
+ self.mlp = ffn_layer(
71
+ in_features=dim,
72
+ hidden_features=mlp_hidden_dim,
73
+ act_layer=act_layer,
74
+ drop=drop,
75
+ bias=ffn_bias,
76
+ )
77
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
78
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
79
+
80
+ self.sample_drop_ratio = drop_path
81
+
82
+ def forward(self, x: Tensor) -> Tensor:
83
+ def attn_residual_func(x: Tensor) -> Tensor:
84
+ return self.ls1(self.attn(self.norm1(x)))
85
+
86
+ def ffn_residual_func(x: Tensor) -> Tensor:
87
+ return self.ls2(self.mlp(self.norm2(x)))
88
+
89
+ if self.training and self.sample_drop_ratio > 0.1:
90
+ # the overhead is compensated only for a drop path rate larger than 0.1
91
+ x = drop_add_residual_stochastic_depth(
92
+ x,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ ) -> Tensor:
115
+ # 1) extract subset using permutation
116
+ b, n, d = x.shape
117
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
118
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
119
+ x_subset = x[brange]
120
+
121
+ # 2) apply residual_func to get residual
122
+ residual = residual_func(x_subset)
123
+
124
+ x_flat = x.flatten(1)
125
+ residual = residual.flatten(1)
126
+
127
+ residual_scale_factor = b / sample_subset_size
128
+
129
+ # 3) add the residual
130
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
131
+ return x_plus_residual.view_as(x)
132
+
133
+
134
+ def get_branges_scales(x, sample_drop_ratio=0.0):
135
+ b, n, d = x.shape
136
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
137
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
138
+ residual_scale_factor = b / sample_subset_size
139
+ return brange, residual_scale_factor
140
+
141
+
142
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
143
+ if scaling_vector is None:
144
+ x_flat = x.flatten(1)
145
+ residual = residual.flatten(1)
146
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
147
+ else:
148
+ x_plus_residual = scaled_index_add(
149
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
150
+ )
151
+ return x_plus_residual
152
+
153
+
154
+ attn_bias_cache: Dict[Tuple, Any] = {}
155
+
156
+
157
+ def get_attn_bias_and_cat(x_list, branges=None):
158
+ """
159
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
160
+ """
161
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
162
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
163
+ if all_shapes not in attn_bias_cache.keys():
164
+ seqlens = []
165
+ for b, x in zip(batch_sizes, x_list):
166
+ for _ in range(b):
167
+ seqlens.append(x.shape[1])
168
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
169
+ attn_bias._batch_sizes = batch_sizes
170
+ attn_bias_cache[all_shapes] = attn_bias
171
+
172
+ if branges is not None:
173
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
174
+ else:
175
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
176
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
177
+
178
+ return attn_bias_cache[all_shapes], cat_tensors
179
+
180
+
181
+ def drop_add_residual_stochastic_depth_list(
182
+ x_list: List[Tensor],
183
+ residual_func: Callable[[Tensor, Any], Tensor],
184
+ sample_drop_ratio: float = 0.0,
185
+ scaling_vector=None,
186
+ ) -> Tensor:
187
+ # 1) generate random set of indices for dropping samples in the batch
188
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
189
+ branges = [s[0] for s in branges_scales]
190
+ residual_scale_factors = [s[1] for s in branges_scales]
191
+
192
+ # 2) get attention bias and index+concat the tensors
193
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
194
+
195
+ # 3) apply residual_func to get residual, and split the result
196
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
197
+
198
+ outputs = []
199
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
200
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
201
+ return outputs
202
+
203
+
204
+ class NestedTensorBlock(Block):
205
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
206
+ """
207
+ x_list contains a list of tensors to nest together and run
208
+ """
209
+ assert isinstance(self.attn, MemEffAttention)
210
+
211
+ if self.training and self.sample_drop_ratio > 0.0:
212
+
213
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
214
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
215
+
216
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
217
+ return self.mlp(self.norm2(x))
218
+
219
+ x_list = drop_add_residual_stochastic_depth_list(
220
+ x_list,
221
+ residual_func=attn_residual_func,
222
+ sample_drop_ratio=self.sample_drop_ratio,
223
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
224
+ )
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=ffn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ return x_list
232
+ else:
233
+
234
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
235
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
236
+
237
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
238
+ return self.ls2(self.mlp(self.norm2(x)))
239
+
240
+ attn_bias, x = get_attn_bias_and_cat(x_list)
241
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
242
+ x = x + ffn_residual_func(x)
243
+ return attn_bias.split(x)
244
+
245
+ def forward(self, x_or_x_list):
246
+ if isinstance(x_or_x_list, Tensor):
247
+ return super().forward(x_or_x_list)
248
+ elif isinstance(x_or_x_list, list):
249
+ assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
250
+ return self.forward_nested(x_or_x_list)
251
+ else:
252
+ raise AssertionError