"""Gradio demo: monocular depth estimation with a 3D point-cloud view.

At import time the script downloads the Depth Anything V2 (ViT-L) checkpoint
if it is missing, loads the model, and builds the Gradio UI.  Running the
file as a script launches the app.
"""

# NOTE: the original import order is kept as-is; it is not regrouped because
# the load order of the native-extension packages below has not been verified
# to be irrelevant.
import glob
import gradio as gr
import matplotlib
import numpy as np
from PIL import Image
import torch
import tempfile
from gradio_imageslider import ImageSlider
import plotly.graph_objects as go
import plotly.express as px
import open3d as o3d
from depth_anything_v2.dpt import DepthAnythingV2
import os
import gdown

# Define path and file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download if not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)

css = """
#img-display-container {
    max-height: 100vh;
}
#img-display-input {
    max-height: 80vh;
}
#img-display-output {
    max-height: 80vh;
}
#download {
    height: 62px;
}
h1 {
    text-align: center;
    font-size: 3rem;
    font-weight: bold;
    margin: 2rem 0;
    color: #2c3e50;
}
"""

DEVICE = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}
encoder = 'vitl'

model = DepthAnythingV2(**model_configs[encoder])
# Load from the same relative directory the checkpoint was downloaded into.
# The previous hard-coded absolute path (/home/user/app/checkpoints/...) only
# matched the download location when the working directory was /home/user/app.
state_dict = torch.load(
    os.path.join(checkpoint_dir, f'depth_anything_v2_{encoder}.pth'),
    map_location="cpu",
)
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()

title = "Depth Estimation, 3D Visualization"
description = """Official demo for **Depth Estimation, 3D Visualization**."""

# Focal length (pixels) used wherever the caller does not supply one.
DEFAULT_FOCAL_LENGTH = 470.4
# Lower bound and step of the "number of points" slider.
POINTS_SLIDER_MIN = 1000
POINTS_SLIDER_STEP = 1000


def predict_depth(image):
    """Run the loaded model on ``image`` and return its depth prediction.

    ``on_submit`` passes the uploaded image with its channel axis reversed;
    presumably ``infer_image`` expects BGR input -- confirm against the
    ``depth_anything_v2`` package.
    """
    return model.infer_image(image)


def calculate_max_points(image):
    """Calculate maximum points based on image dimensions (3x pixel count).

    Returns 10000 when no image is given; otherwise ``3 * height * width``
    clamped to the range [1000, 1000000].
    """
    if image is None:
        return 10000  # Default value
    h, w = image.shape[:2]
    max_points = h * w * 3
    # Ensure minimum and reasonable maximum values
    return max(1000, min(max_points, 1000000))


def update_slider_on_image_upload(image):
    """Update the points slider when an image is uploaded.

    The default is 10% of the maximum (capped at 10000), rounded down to the
    slider step and never below the slider minimum.  Without that clamp a
    small image produced a default such as 100, which is below the minimum
    of 1000 and off the step grid.
    """
    max_points = calculate_max_points(image)
    default_value = min(10000, max_points // 10)  # 10% of max points as default
    default_value = (default_value // POINTS_SLIDER_STEP) * POINTS_SLIDER_STEP
    default_value = max(POINTS_SLIDER_MIN, default_value)
    return gr.Slider(
        minimum=POINTS_SLIDER_MIN,
        maximum=max_points,
        value=default_value,
        step=POINTS_SLIDER_STEP,
        label=f"Number of 3D points (max: {max_points:,})",
    )


def _downsample_step(h, w, max_points):
    """Return the pixel stride that keeps the sampled grid near ``max_points``.

    The stride is rounded *up*: rounding down could overshoot the requested
    budget by up to roughly 4x (e.g. a ratio just under 4 gave stride 1).
    The result is still approximate because each axis is sampled with
    ``ceil(size / stride)`` points.
    """
    budget = max(1, int(max_points))  # guard against a zero/negative budget
    return max(1, int(np.ceil(np.sqrt(h * w / budget))))


def _sample_grid(image, depth_map, max_points):
    """Downsample ``depth_map`` and ``image`` on a shared regular grid.

    Returns ``(x_coords, y_coords, depth_values, colors)`` where the
    coordinate grids are in pixels and ``colors`` is an ``(N, 3)`` array taken
    from the first three channels of ``image``.
    """
    h, w = depth_map.shape
    step = _downsample_step(h, w, max_points)
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    depth_values = depth_map[::step, ::step]
    # Keep only the first three channels so an extra alpha channel cannot
    # break the (N, 3) reshape.
    colors = image[::step, ::step, :3].reshape(-1, 3)
    return x_coords, y_coords, depth_values, colors


def _back_project(x_coords, y_coords, depth_values, h, w, focal_x, focal_y):
    """Back-project pixel coordinates to camera space.

    The principal point is taken to be the image centre ``(w / 2, h / 2)``.
    Returns the ``(x, y, z)`` coordinate arrays, with ``z`` equal to
    ``depth_values``.
    """
    x_cam = (x_coords - w / 2) / focal_x
    y_cam = (y_coords - h / 2) / focal_y
    return x_cam * depth_values, y_cam * depth_values, depth_values


def _reserve_temp_path(suffix):
    """Create an empty temporary file and return its path.

    The handle is closed before returning so that it is not leaked and so the
    path can be reopened for writing by other libraries.  The file itself is
    left on disk (``delete=False``) for Gradio to serve.
    """
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as handle:
        return handle.name


def create_3d_depth_visualization(image, depth_map, max_points=10000):
    """Create an interactive 3D visualization of the depth map.

    Points are plotted at their pixel ``(x, y)`` position with the depth value
    as ``z`` and coloured with the corresponding image pixel.

    Args:
        image: colour image array aligned with ``depth_map``.
        depth_map: 2-D array of depth values.
        max_points: approximate upper bound on the number of plotted points.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    x_coords, y_coords, depth_values, colors_flat = _sample_grid(
        image, depth_map, max_points
    )

    # Create 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=x_coords.flatten(),
        y=y_coords.flatten(),
        z=depth_values.flatten(),
        mode='markers',
        marker=dict(
            size=2,
            color=colors_flat,
            opacity=0.8,
        ),
        # "<br>" separates hover lines; "<extra></extra>" hides the secondary
        # trace-name box.  The previous literal contained raw newlines, which
        # is not valid inside a single-quoted string.
        hovertemplate='Position: (%{x:.0f}, %{y:.0f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>',
    )])

    fig.update_layout(
        title="3D Depth Visualization (Hover to see depth values)",
        scene=dict(
            xaxis_title="X (pixels)",
            yaxis_title="Y (pixels)",
            zaxis_title="Depth",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            ),
        ),
        width=600,
        height=500,
    )

    return fig


def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
    """Create a point cloud from depth map using camera intrinsics.

    Args:
        image: colour image array aligned with ``depth_map``; values are
            divided by 255 to obtain the point colours.
        depth_map: 2-D array of depth values.
        focal_length_x: horizontal focal length in pixels.
        focal_length_y: vertical focal length in pixels.
        max_points: approximate upper bound on the number of points.

    Returns:
        An ``open3d.geometry.PointCloud`` with points and colours set.
    """
    h, w = depth_map.shape
    x_coords, y_coords, depth_values, sampled_colors = _sample_grid(
        image, depth_map, max_points
    )

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d, y_3d, z_3d = _back_project(
        x_coords, y_coords, depth_values, h, w, focal_length_x, focal_length_y
    )
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
    colors = sampled_colors / 255.0

    # Create Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)

    return pcd


def create_enhanced_3d_visualization(image, depth_map, max_points=10000,
                                     focal_length_x=DEFAULT_FOCAL_LENGTH,
                                     focal_length_y=DEFAULT_FOCAL_LENGTH):
    """Create an enhanced 3D visualization using proper camera projection.

    Args:
        image: colour image array aligned with ``depth_map``.
        depth_map: 2-D array of depth values.
        max_points: approximate upper bound on the number of plotted points.
        focal_length_x: horizontal focal length in pixels.  Defaults to the
            value that used to be hard-coded, so existing callers are
            unaffected.
        focal_length_y: vertical focal length in pixels (same default).

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    h, w = depth_map.shape
    x_coords, y_coords, depth_values, colors_flat = _sample_grid(
        image, depth_map, max_points
    )

    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d, y_3d, z_3d = _back_project(
        x_coords, y_coords, depth_values, h, w, focal_length_x, focal_length_y
    )

    # Create 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_3d.flatten(),
        y=y_3d.flatten(),
        z=z_3d.flatten(),
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9,
        ),
        # See create_3d_depth_visualization for the "<br>"/"<extra>" markup.
        hovertemplate='3D Position: (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      'Depth: %{z:.2f}<br>' +
                      '<extra></extra>',
    )])

    # NOTE(review): the axes are labelled in meters, but on_submit passes a
    # depth map normalised to 0-255, so the plotted values are not metric.
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1),
            ),
            aspectmode='data',
        ),
        width=700,
        height=600,
    )

    return fig


# NOTE(review): the source arrived with its indentation collapsed; which
# components sit inside each gr.Row() below is reconstructed from statement
# order -- confirm the intended layout.
with gr.Blocks(css=css) as demo:
    # The h1 rule in `css` styles this heading.
    gr.HTML(f"<h1>{title}</h1>")
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')

    with gr.Row():
        submit = gr.Button(value="Compute Depth", variant="primary")
        points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                  label="Number of 3D points (upload image to update max)")

    with gr.Row():
        focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                   label="Focal Length X (pixels)")
        focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                   label="Focal Length Y (pixels)")

    with gr.Row():
        gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
        raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
        point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

    # 3D Visualization
    gr.Markdown("### 3D Point Cloud Visualization")
    gr.Markdown("Enhanced 3D visualization using proper camera projection. "
                "Hover over points to see 3D coordinates.")
    depth_3d_plot = gr.Plot(label="3D Point Cloud")

    cmap = matplotlib.colormaps.get_cmap('Spectral_r')

    def on_submit(image, num_points, focal_x, focal_y):
        """Compute depth for ``image`` and build every downloadable output.

        Args:
            image: uploaded image as a numpy array, or ``None`` if nothing
                has been uploaded.
            num_points: approximate point budget for the 3D outputs.
            focal_x: horizontal focal length in pixels.
            focal_y: vertical focal length in pixels.

        Returns:
            ``[(original, colored_depth), gray_png_path, raw_png_path,
            ply_path, plotly_figure]`` matching the click handler's outputs.

        Raises:
            gr.Error: if no image has been uploaded.
        """
        if image is None:
            raise gr.Error("Please upload an image before computing depth.")

        original_image = image.copy()

        # Reverse the channel axis before inference (see predict_depth).
        depth = predict_depth(image[:, :, ::-1])

        raw_depth_path = _reserve_temp_path('.png')
        Image.fromarray(depth.astype('uint16')).save(raw_depth_path)

        # Normalise to 0-255.  A constant depth map has zero range; map it to
        # all zeros instead of dividing by zero.
        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if depth_range > 0:
            depth = (depth - depth_min) / depth_range * 255.0
        else:
            depth = np.zeros_like(depth)
        depth = depth.astype(np.uint8)
        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)

        gray_depth_path = _reserve_temp_path('.png')
        Image.fromarray(depth).save(gray_depth_path)

        # Create point cloud
        pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
        pointcloud_path = _reserve_temp_path('.ply')
        o3d.io.write_point_cloud(pointcloud_path, pcd)

        # Create enhanced 3D visualization with the same intrinsics as the
        # exported point cloud so the plot and the .ply file agree.
        depth_3d = create_enhanced_3d_visualization(
            original_image, depth, max_points=num_points,
            focal_length_x=focal_x, focal_length_y=focal_y,
        )

        return [(original_image, colored_depth), gray_depth_path, raw_depth_path,
                pointcloud_path, depth_3d]

    # Update slider when image is uploaded
    input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[input_image],
        outputs=[points_slider]
    )

    submit.click(
        on_submit,
        inputs=[input_image, points_slider, focal_length_x, focal_length_y],
        outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot],
    )


if __name__ == '__main__':
    demo.queue().launch()