import os
import tempfile

import gdown
import gradio as gr
import matplotlib
import numpy as np
import open3d as o3d
import plotly.graph_objects as go
import torch
from gradio_imageslider import ImageSlider
from PIL import Image

from depth_anything_v2.dpt import DepthAnythingV2

# Define the checkpoint path and Google Drive file ID
checkpoint_dir = "checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
model_file = os.path.join(checkpoint_dir, "depth_anything_v2_vitl.pth")
gdrive_url = "https://drive.google.com/uc?id=141Mhq2jonkUBcVBnNqNSeyIZYtH5l4K5"

# Download the checkpoint if it is not already present
if not os.path.exists(model_file):
    print("Downloading model from Google Drive...")
    gdown.download(gdrive_url, model_file, quiet=False)
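
# Note (assumption, not from the original Space): 'uc?id=' downloads require the
# Drive file to be shared publicly ("anyone with the link"); recent gdown
# versions also accept the bare file ID, e.g.
# gdown.download(id="141Mhq...", output=model_file, quiet=False).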
css = """ | |
#img-display-container { | |
max-height: 100vh; | |
} | |
#img-display-input { | |
max-height: 80vh; | |
} | |
#img-display-output { | |
max-height: 80vh; | |
} | |
#download { | |
height: 62px; | |
} | |
h1 { | |
text-align: center; | |
font-size: 3rem; | |
font-weight: bold; | |
margin: 2rem 0; | |
color: #2c3e50; | |
} | |
""" | |

DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}

encoder = 'vitl'
model = DepthAnythingV2(**model_configs[encoder])
# Load the checkpoint downloaded above; the original hardcoded
# '/home/user/app/checkpoints/...' path only resolves inside the Space container
# and would miss the file gdown just fetched into the relative 'checkpoints' dir.
state_dict = torch.load(model_file, map_location="cpu")
model.load_state_dict(state_dict)
model = model.to(DEVICE).eval()
title = "Depth Estimation, 3D Visualization" | |
description = """Official demo for **Depth Estimation, 3D Visualization**.""" | |
def predict_depth(image): | |
return model.infer_image(image) | |
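
# Note: per the upstream Depth-Anything-V2 repo, infer_image takes an
# OpenCV-style BGR uint8 image (hence the channel flip in on_submit below) and
# returns an HxW float32 map of relative depth, where larger values are closer --
# which is why the raw 16-bit download is labeled "can be considered as disparity".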

def calculate_max_points(image):
    """Calculate the maximum number of 3D points for an image (3x its pixel count)"""
    if image is None:
        return 10000  # Default value
    h, w = image.shape[:2]
    max_points = h * w * 3
    # Clamp to a sane range
    return max(1000, min(max_points, 1000000))
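
# Worked example (illustration only, not used by the app): a 640x480 upload
# yields 640 * 480 * 3 = 921,600, inside the [1,000, 1,000,000] clamp, while a
# 1920x1080 upload (6,220,800) is clamped down to 1,000,000.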

def update_slider_on_image_upload(image):
    """Update the points slider when an image is uploaded"""
    max_points = calculate_max_points(image)
    default_value = min(10000, max_points // 10)  # 10% of max points as default
    return gr.Slider(minimum=1000, maximum=max_points, value=default_value, step=1000,
                     label=f"Number of 3D points (max: {max_points:,})")

def create_3d_depth_visualization(image, depth_map, max_points=10000):
    """Create an interactive 3D visualization of the depth map.

    Note: kept for reference; the demo below uses create_enhanced_3d_visualization,
    which applies a proper camera projection instead of raw pixel coordinates.
    """
    h, w = depth_map.shape
    # Downsample to keep the point count (and render time) manageable
    step = max(1, int(np.sqrt(h * w / max_points)))
    # Create coordinate grids
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    depth_values = depth_map[::step, ::step]
    # Flatten arrays
    x_flat = x_coords.flatten()
    y_flat = y_coords.flatten()
    z_flat = depth_values.flatten()
    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)
    # Create 3D scatter plot
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=2,
            color=colors_flat,
            opacity=0.8
        ),
        hovertemplate='<b>Position:</b> (%{x:.0f}, %{y:.0f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Depth Visualization (Hover to see depth values)",
        scene=dict(
            xaxis_title="X (pixels)",
            yaxis_title="Y (pixels)",
            zaxis_title="Depth",
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5)
            )
        ),
        width=600,
        height=500
    )
    return fig
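
# Downsampling sketch: step = max(1, int(sqrt(h * w / max_points))) keeps roughly
# max_points samples. For a 1000x750 map with max_points=10,000,
# step = int(sqrt(75)) = 8, giving ceil(750/8) * ceil(1000/8) = 94 * 125 = 11,750
# points; the int() truncation means the count can slightly exceed the cap.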

def create_point_cloud(image, depth_map, focal_length_x=470.4, focal_length_y=470.4, max_points=100000):
    """Create a point cloud from a depth map using pinhole camera intrinsics"""
    h, w = depth_map.shape
    # Downsample to keep the point count manageable
    step = max(1, int(np.sqrt(h * w / max_points)))
    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    x_cam = (x_coords - w / 2) / focal_length_x
    y_cam = (y_coords - h / 2) / focal_length_y
    # Get depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Stack into an (N, 3) array
    points = np.stack([x_3d.flatten(), y_3d.flatten(), z_3d.flatten()], axis=1)
    # Get corresponding image colors, scaled to [0, 1] for Open3D
    image_colors = image[::step, ::step, :]
    colors = image_colors.reshape(-1, 3) / 255.0
    # Create Open3D point cloud
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(points)
    pcd.colors = o3d.utility.Vector3dVector(colors)
    return pcd
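
# Back-projection sketch (standard pinhole model, principal point assumed at the
# image center): a pixel (u, v) with depth z maps to
#   X = (u - w/2) * z / fx,   Y = (v - h/2) * z / fy,   Z = z,
# which is what the vectorized x_cam/y_cam computation above performs. With the
# default fx = 470.4, a pixel 235 px right of center at z = 2.0 lands at
# X = 235 * 2.0 / 470.4 ~= 1.0.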

def create_enhanced_3d_visualization(image, depth_map, max_points=10000):
    """Create an enhanced 3D visualization using proper camera projection"""
    h, w = depth_map.shape
    # Downsample to keep the point count manageable
    step = max(1, int(np.sqrt(h * w / max_points)))
    # Create mesh grid for camera coordinates
    y_coords, x_coords = np.mgrid[0:h:step, 0:w:step]
    # Convert to camera coordinates (normalized by focal length)
    focal_length = 470.4  # Default focal length
    x_cam = (x_coords - w / 2) / focal_length
    y_cam = (y_coords - h / 2) / focal_length
    # Get depth values
    depth_values = depth_map[::step, ::step]
    # Calculate 3D points: (x_cam * depth, y_cam * depth, depth)
    x_3d = x_cam * depth_values
    y_3d = y_cam * depth_values
    z_3d = depth_values
    # Flatten arrays
    x_flat = x_3d.flatten()
    y_flat = y_3d.flatten()
    z_flat = z_3d.flatten()
    # Get corresponding image colors
    image_colors = image[::step, ::step, :]
    colors_flat = image_colors.reshape(-1, 3)
    # Create 3D scatter plot with proper camera projection
    fig = go.Figure(data=[go.Scatter3d(
        x=x_flat,
        y=y_flat,
        z=z_flat,
        mode='markers',
        marker=dict(
            size=1.5,
            color=colors_flat,
            opacity=0.9
        ),
        hovertemplate='<b>3D Position:</b> (%{x:.3f}, %{y:.3f}, %{z:.3f})<br>' +
                      '<b>Depth:</b> %{z:.2f}<br>' +
                      '<extra></extra>'
    )])
    fig.update_layout(
        title="3D Point Cloud Visualization (Camera Projection)",
        scene=dict(
            xaxis_title="X (meters)",
            yaxis_title="Y (meters)",
            zaxis_title="Z (meters)",
            camera=dict(
                eye=dict(x=2.0, y=2.0, z=2.0),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='data'
        ),
        width=700,
        height=600
    )
    return fig
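
# Note on units: the axis labels say "meters", but the default Depth Anything V2
# checkpoints predict relative depth (and on_submit passes the 0-255 normalized
# map), so coordinates are only truly metric with a metric-depth checkpoint and
# the camera's real focal length.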

with gr.Blocks(css=css) as demo:
    gr.HTML(f"<h1>{title}</h1>")
    gr.Markdown(description)
    gr.Markdown("### Depth Prediction demo")

    with gr.Row():
        input_image = gr.Image(label="Input Image", type='numpy', elem_id='img-display-input')
        depth_image_slider = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output')
    with gr.Row():
        submit = gr.Button(value="Compute Depth", variant="primary")
        points_slider = gr.Slider(minimum=1000, maximum=10000, value=10000, step=1000,
                                  label="Number of 3D points (upload image to update max)")
    with gr.Row():
        focal_length_x = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                   label="Focal Length X (pixels)")
        focal_length_y = gr.Slider(minimum=100, maximum=1000, value=470.4, step=10,
                                   label="Focal Length Y (pixels)")
    with gr.Row():
        gray_depth_file = gr.File(label="Grayscale depth map", elem_id="download")
        raw_file = gr.File(label="16-bit raw output (can be considered as disparity)", elem_id="download")
        point_cloud_file = gr.File(label="Point Cloud (.ply)", elem_id="download")

    # 3D Visualization
    gr.Markdown("### 3D Point Cloud Visualization")
    gr.Markdown("Enhanced 3D visualization using proper camera projection. Hover over points to see 3D coordinates.")
    depth_3d_plot = gr.Plot(label="3D Point Cloud")

    cmap = matplotlib.colormaps.get_cmap('Spectral_r')
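    # matplotlib.colormaps is the modern (matplotlib >= 3.6) replacement for the
    # deprecated matplotlib.cm.get_cmap; calling cmap on a uint8 array indexes the
    # 256-entry lookup table and returns RGBA floats in [0, 1], which on_submit
    # scales back to 0-255.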

    def on_submit(image, num_points, focal_x, focal_y):
        original_image = image.copy()

        # The model expects BGR, Gradio supplies RGB, so flip the channel order
        depth = predict_depth(image[:, :, ::-1])

        raw_depth = Image.fromarray(depth.astype('uint16'))
        tmp_raw_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        raw_depth.save(tmp_raw_depth.name)

        # Normalize to 0-255 for display (epsilon guards against a constant map)
        depth = (depth - depth.min()) / max(depth.max() - depth.min(), 1e-8) * 255.0
        depth = depth.astype(np.uint8)
        colored_depth = (cmap(depth)[:, :, :3] * 255).astype(np.uint8)

        gray_depth = Image.fromarray(depth)
        tmp_gray_depth = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
        gray_depth.save(tmp_gray_depth.name)

        # Create point cloud (note: built from the normalized relative depth)
        pcd = create_point_cloud(original_image, depth, focal_x, focal_y, max_points=num_points)
        tmp_pointcloud = tempfile.NamedTemporaryFile(suffix='.ply', delete=False)
        o3d.io.write_point_cloud(tmp_pointcloud.name, pcd)

        # Create enhanced 3D visualization
        depth_3d = create_enhanced_3d_visualization(original_image, depth, max_points=num_points)

        return [(original_image, colored_depth), tmp_gray_depth.name, tmp_raw_depth.name, tmp_pointcloud.name, depth_3d]

    # Update the points slider whenever a new image is uploaded
    input_image.change(
        fn=update_slider_on_image_upload,
        inputs=[input_image],
        outputs=[points_slider]
    )

    submit.click(on_submit, inputs=[input_image, points_slider, focal_length_x, focal_length_y],
                 outputs=[depth_image_slider, gray_depth_file, raw_file, point_cloud_file, depth_3d_plot])

if __name__ == '__main__':
    demo.queue().launch()