Spaces:
Sleeping
Sleeping
import cv2 | |
import numpy as np | |
import gradio as gr | |
from ultralytics import YOLO | |
from transformers import AutoImageProcessor, AutoModelForDepthEstimation | |
from PIL import Image | |
import torch # Added torch import for depth estimation | |
# Tree detector: YOLO weights used to find trees in the uploaded photo.
# Point this at your own checkpoint (a local path or a Hugging Face Hub reference).
yolo_model = YOLO("./data/best.pt")

# Depth estimator: the image processor (pre-processing) and the depth model are
# both loaded from the Hugging Face Hub under the same checkpoint id.
DEPTH_CHECKPOINT = "Intel/dpt-large"
processor = AutoImageProcessor.from_pretrained(DEPTH_CHECKPOINT)
depth_model = AutoModelForDepthEstimation.from_pretrained(DEPTH_CHECKPOINT)
def _estimate_depth_map(image_rgb):
    """Run the depth model on an RGB uint8 array and return a 2-D depth map at input resolution."""
    inputs = processor(images=Image.fromarray(image_rgb), return_tensors="pt")
    with torch.no_grad():
        predicted_depth = depth_model(**inputs).predicted_depth
        # predicted_depth is 3-D here (batch, h, w); interpolate needs a channel
        # axis, so add one, resize to the original image size, then drop the
        # singleton batch/channel axes again.
        depth = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
    return depth.squeeze().cpu().numpy()


def _draw_detections(image_bgr, detections):
    """Draw each detection's box and height label onto ``image_bgr`` in place."""
    for item in detections:
        x_min, y_min, x_max, y_max = item["box"]
        cv2.rectangle(image_bgr, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        # The label normally sits just above the box; for boxes touching the top
        # edge, y_min - 10 would be off-image, so keep the baseline inside the frame.
        label_y = max(y_min - 10, 20)
        cv2.putText(image_bgr, f"Height: {item['height_cm']} cm", (x_min, label_y),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)


# Function to process image and estimate tree heights
def process_image(image, focal_length_mm=3.6, sensor_height_mm=4.8, depth_scale=100):
    """
    Detect trees in an image and estimate their heights with a pinhole-camera formula.

    Trees are located with the module-level YOLO model, and a per-pixel depth map
    comes from the module-level depth model. For each detection, the mean depth
    inside its box is multiplied by ``depth_scale`` to get a distance D in
    centimeters. The height is then

        H = (h_pixel * D * sensor_height_mm) / (focal_length_mm * image_height)

    Args:
        image: PIL Image from Gradio (a 3-channel RGB numpy array is also accepted),
            or None when the user submitted without uploading anything.
        focal_length_mm: Camera focal length in millimeters (default: 3.6). Must be > 0.
        sensor_height_mm: Camera sensor height in millimeters (default: 4.8).
        depth_scale: Factor converting the depth model's output to centimeters
            (default: 100); tune empirically.

    Returns:
        ``(annotated_rgb_image, detections)``, where ``detections`` is a list of
        ``{"box": (x_min, y_min, x_max, y_max), "height_cm": float}`` dicts.
        Boxes with zero area after clamping to the image are omitted.
        Returns ``(None, [])`` when ``image`` is None.

    Raises:
        ValueError: if ``focal_length_mm`` is missing or not positive (it is a divisor).
    """
    if image is None:
        return None, []
    if focal_length_mm is None or focal_length_mm <= 0:
        raise ValueError("focal_length_mm must be a positive number")

    # Normalise RGBA / greyscale / palette uploads to 3-channel RGB so that
    # cv2.COLOR_RGB2BGR and the depth processor both receive the expected layout.
    if isinstance(image, Image.Image):
        image = image.convert("RGB")
    image_rgb = np.array(image)
    image_cv = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
    image_height, image_width = image_cv.shape[:2]

    # Step 1: detect trees (YOLO is given the BGR array, OpenCV channel order).
    results = yolo_model(image_cv)
    boxes = results[0].boxes.xyxy.cpu().numpy()  # rows of [x_min, y_min, x_max, y_max]

    # Step 2: per-pixel depth at full image resolution. The RGB array is reused
    # directly instead of converting BGR back to RGB.
    depth_map = _estimate_depth_map(image_rgb)

    # Step 3: estimate a height for each detection.
    # NOTE(review): DPT-family models are generally documented as predicting
    # *relative inverse* depth (larger value = closer). If that holds for this
    # checkpoint, avg_depth * depth_scale is not proportional to distance, and
    # the heights are only meaningful after calibration. Confirm before relying
    # on absolute values.
    output = []
    for box in boxes:
        x_min, y_min, x_max, y_max = map(int, box)
        # Clamp to the frame: a negative start index would wrap around in the slice.
        x_min, y_min = max(x_min, 0), max(y_min, 0)
        x_max, y_max = min(x_max, image_width), min(y_max, image_height)

        depth_region = depth_map[y_min:y_max, x_min:x_max]
        if depth_region.size == 0:
            # Zero-area box: its mean depth would be NaN, so there is no height to report.
            continue

        h_pixel = y_max - y_min
        # float() so the result is a plain Python float (JSON-serialisable for gr.JSON).
        avg_depth = float(depth_region.mean())
        distance_cm = avg_depth * depth_scale
        tree_height_cm = (h_pixel * distance_cm * sensor_height_mm) / (focal_length_mm * image_height)
        output.append({
            "box": (x_min, y_min, x_max, y_max),
            "height_cm": round(tree_height_cm, 2),
        })

    # Step 4: annotate, then convert back to RGB for Gradio.
    _draw_detections(image_cv, output)
    return cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB), output
# Create Gradio interface.
# Inputs, in the same order as process_image's positional parameters:
# the photo, then focal length, sensor height and depth scale.
input_components = [
    gr.Image(type="pil", label="Upload Image"),
    gr.Number(label="Focal Length (mm)", value=3.6),
    gr.Number(label="Sensor Height (mm)", value=4.8),
    gr.Number(label="Depth Scale Factor", value=100),
]
# Outputs, matching process_image's return tuple: annotated image, per-tree results.
output_components = [
    gr.Image(label="Detected Trees with Heights"),
    gr.JSON(label="Tree Heights (cm)"),
]
iface = gr.Interface(
    fn=process_image,
    inputs=input_components,
    outputs=output_components,
    title="Tree Detection and Height Estimation",
    description="Upload an image to detect trees and estimate their heights in centimeters. Adjust camera parameters and depth scale as needed.",
)
# Start the server on port 7860, bound to 0.0.0.0 (all network interfaces, not only localhost).
iface.launch(server_name="0.0.0.0", server_port=7860)