# three/app.py
import gradio as gr
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
import torch
import numpy as np
from PIL import Image
import open3d as o3d
from pathlib import Path
# Load model and feature extractor
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large")
def process_image(image_path):
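    """Estimate a depth map for the uploaded image with DPT and reconstruct a 3D mesh from it."""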
    image_path = Path(image_path) if isinstance(image_path, str) else image_path
    try:
        image_raw = Image.open(image_path).convert("RGB")
    except Exception as e:
        raise gr.Error(f"Error loading image: {e}")
    # Resize to a width of 800 px while maintaining the aspect ratio
    image = image_raw.resize(
        (800, int(800 * image_raw.size[1] / image_raw.size[0])),
        Image.Resampling.LANCZOS,
    )
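    # Run the DPT model to predict relative depth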
    encoding = feature_extractor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**encoding)
        predicted_depth = outputs.predicted_depth
    # Interpolate the prediction to the input resolution and normalize to an 8-bit depth map
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    output = prediction.cpu().numpy()
    if np.max(output) > 0:
        depth_image = (output * 255 / np.max(output)).astype('uint8')
    else:
        depth_image = np.zeros_like(output, dtype='uint8')  # Handle an all-zero prediction
    # Build the mesh; retry with a coarser Poisson depth if reconstruction fails
    try:
        gltf_path = create_3d_obj(np.array(image), depth_image, image_path)
    except Exception:
        gltf_path = create_3d_obj(np.array(image), depth_image, image_path, depth=8)

    return Image.fromarray(depth_image), gltf_path, gltf_path
def create_3d_obj(rgb_image, depth_image, image_path, depth=10):
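    """Back-project the RGB-D pair to a point cloud, run Poisson reconstruction, and export a glTF mesh."""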
    depth_o3d = o3d.geometry.Image(depth_image)
    image_o3d = o3d.geometry.Image(rgb_image)
    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        image_o3d, depth_o3d, convert_rgb_to_intensity=False)

    # Pinhole camera with an assumed 500 px focal length and the principal point at the image center
    w, h = depth_image.shape[1], depth_image.shape[0]
    camera_intrinsic = o3d.camera.PinholeCameraIntrinsic()
    camera_intrinsic.set_intrinsics(w, h, 500, 500, w / 2, h / 2)
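    # Back-project the RGB-D image into a point cloud and orient its normals toward the camera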
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, camera_intrinsic)
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.01, max_nn=30))
    pcd.orient_normals_towards_camera_location(camera_location=np.array([0., 0., 1000.]))
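    # Poisson surface reconstruction; `depth` controls the octree resolution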
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):
        mesh_raw, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
            pcd, depth=depth, width=0, scale=1.1, linear_fit=True)

    # Simplify the raw mesh by vertex clustering
    voxel_size = max(mesh_raw.get_max_bound() - mesh_raw.get_min_bound()) / 256
    mesh = mesh_raw.simplify_vertex_clustering(voxel_size=voxel_size)
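    # Crop to the point cloud's bounding box and export the mesh as glTF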
    bbox = pcd.get_axis_aligned_bounding_box()
    mesh_crop = mesh.crop(bbox)
    gltf_path = f'./{image_path.stem}.gltf'
    o3d.io.write_triangle_mesh(gltf_path, mesh_crop, write_triangle_uvs=True)
    return gltf_path
title = "Zero-shot Depth Estimation with DPT + 3D Model Preview"
description = "Upload an image to generate a depth map and reconstruct a 3D model in .gltf format."
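# Build the Gradio interface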
with gr.Blocks() as demo:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Upload Image")
            generate_button = gr.Button("Generate 3D Model")
        with gr.Column():
            depth_output = gr.Image(label="Predicted Depth", type="pil")
    with gr.Row():
        model_output = gr.Model3D(label="3D Model Preview (GLTF)")
    with gr.Row():
        file_output = gr.File(label="Download 3D GLTF File")
    generate_button.click(
        fn=process_image,
        inputs=[image_input],
        outputs=[depth_output, model_output, file_output],
    )
if __name__ == "__main__":
    demo.launch()