import cv2
import numpy as np
import torch
import gradio as gr
from ultralytics import YOLO
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from PIL import Image
# Load YOLO model for tree detection
# Update the path below to point at your trained weights (a local file or one downloaded from the Hugging Face Hub)
yolo_model = YOLO("./data/best.pt")
# Load depth estimation model and processor from Hugging Face
processor = AutoImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = AutoModelForDepthEstimation.from_pretrained("Intel/dpt-large")
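# Optional: run depth estimation on GPU when available (a commented-out sketch; if enabled,
# the processor inputs inside process_image would also need to be moved to the same device):
# device = "cuda" if torch.cuda.is_available() else "cpu"
# depth_model = depth_model.to(device)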
# Function to process image and estimate tree heights
def process_image(image, focal_length_mm=3.6, sensor_height_mm=4.8, depth_scale=100):
"""
Process an input image to detect trees and estimate their heights.
Args:
image: PIL Image from Gradio
focal_length_mm: Camera focal length in millimeters (default: 3.6)
sensor_height_mm: Camera sensor height in millimeters (default: 4.8)
depth_scale: Scaling factor to convert depth map to centimeters (default: 100)
Returns:
Annotated image and JSON with tree heights
"""
    # Convert PIL image to OpenCV format (BGR)
    image_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    image_height = image_cv.shape[0]  # Image height in pixels
    # Step 1: Run YOLO to detect trees
    results = yolo_model(image_cv)
    boxes = results[0].boxes.xyxy.cpu().numpy()  # Bounding boxes [x_min, y_min, x_max, y_max]
    # Step 2: Prepare image for depth estimation
    # Convert OpenCV image (BGR) back to PIL (RGB) for the transformers processor
    image_pil = Image.fromarray(cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB))
    # Preprocess image for depth model
    inputs = processor(images=image_pil, return_tensors="pt")
    # Step 3: Run depth estimation
    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth
    # Resize depth map to match input image size
    depth_map = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=(image_cv.shape[0], image_cv.shape[1]),
        mode="bicubic",
        align_corners=False,
    ).squeeze().cpu().numpy()
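    # Note: DPT-Large predicts relative (not metric) depth, so the averaged depth below is only
    # proportional to the true distance; depth_scale acts as a rough, scene-dependent calibration.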
    # Step 4: Process each detected tree
    output = []
    for box in boxes:
        x_min, y_min, x_max, y_max = map(int, box)
        h_pixel = y_max - y_min  # Bounding box height in pixels
        # Extract depth for the tree's bounding box
        depth_region = depth_map[y_min:y_max, x_min:x_max]
        avg_depth = np.mean(depth_region)  # Average depth (relative units)
        # Convert depth to centimeters using scaling factor
        distance_cm = avg_depth * depth_scale  # Tune depth_scale based on testing
        # Calculate tree height in centimeters
        # Formula: H = (h_pixel * D * sensor_height) / (focal_length * image_height)
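        # Rough sanity check with assumed (not measured) values: h_pixel = 400 px,
        # distance_cm = 500 cm, sensor_height_mm = 4.8, focal_length_mm = 3.6, image_height = 1080 px
        # gives H = (400 * 500 * 4.8) / (3.6 * 1080) ≈ 246.9 cm.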
        tree_height_cm = (h_pixel * distance_cm * sensor_height_mm) / (focal_length_mm * image_height)
        tree_height_cm = round(tree_height_cm, 2)  # Round to 2 decimal places
        output.append({
            "box": (x_min, y_min, x_max, y_max),
            "height_cm": tree_height_cm
        })
    # Step 5: Draw results on the image
    for item in output:
        x_min, y_min, x_max, y_max = item["box"]
        # Draw bounding box
        cv2.rectangle(image_cv, (x_min, y_min), (x_max, y_max), (0, 255, 0), 2)
        # Add height text
        cv2.putText(image_cv, f"Height: {item['height_cm']} cm", (x_min, y_min - 10),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
    # Convert back to RGB for Gradio
    image_rgb = cv2.cvtColor(image_cv, cv2.COLOR_BGR2RGB)
    return image_rgb, output
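# Quick local sanity check (a sketch; "sample_tree.jpg" is a placeholder path, adjust as needed):
# annotated, heights = process_image(Image.open("sample_tree.jpg"),
#                                     focal_length_mm=3.6, sensor_height_mm=4.8, depth_scale=100)
# print(heights)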
# Create Gradio interface
iface = gr.Interface(
    fn=process_image,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Number(label="Focal Length (mm)", value=3.6),
        gr.Number(label="Sensor Height (mm)", value=4.8),
        gr.Number(label="Depth Scale Factor", value=100)
    ],
    outputs=[
        gr.Image(label="Detected Trees with Heights"),
        gr.JSON(label="Tree Heights (cm)")
    ],
    title="Tree Detection and Height Estimation",
    description="Upload an image to detect trees and estimate their heights in centimeters. Adjust camera parameters and depth scale as needed."
)
# Launch the interface
iface.launch()