pranaya20's picture
Create app.py
8ae2ea0 verified
raw
history blame
2.26 kB
import gradio as gr
from ultralytics import YOLO
import torch
import cv2
import numpy as np
import wikipedia
from PIL import Image
# Object detector. NOTE(review): "yolov8n.pt" is the stock Ultralytics
# checkpoint (COCO-trained per Ultralytics docs); COCO has no "tree" class,
# so tree-specific weights are needed for real tree detection.
yolo_model = YOLO("yolov8n.pt")

# MiDaS monocular depth model (small variant), run on CPU in inference mode.
midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small")
midas.to("cpu").eval()

# The hub's "transforms" entry returns a module exposing one preprocessing
# pipeline per model family; the one matching MiDaS_small is
# `small_transform` (there is no `.small` attribute, which raised at startup).
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms").small_transform
# Detector class index used to pick the "tree" box. NOTE(review): with the
# stock COCO yolov8n.pt weights, class 0 is "person" -- COCO has no tree
# class. Change this (and the weights) for genuine tree detection.
TARGET_CLASS_ID = 0

# Arbitrary factor turning mean MiDaS output into a "meters" figure. MiDaS
# predicts relative (inverse) depth, so this is a demo placeholder only.
DEPTH_TO_METERS = 1.8


def _box_to_slices(box, shape):
    """Convert an (x1, y1, x2, y2) box into (row, col) slices clipped to *shape*.

    Coordinates are truncated with ``int`` as before. Each slice is forced to
    span at least one pixel so cropping and averaging never see an empty region.
    """
    height, width = shape[:2]
    x1, y1, x2, y2 = box
    left = min(max(int(x1), 0), width - 1)
    top = min(max(int(y1), 0), height - 1)
    right = min(max(int(x2), left + 1), width)
    bottom = min(max(int(y2), top + 1), height)
    return slice(top, bottom), slice(left, right)


def _relative_depth_map(image_rgb):
    """Run MiDaS on an HxWx3 RGB uint8 array; return an HxW numpy depth map."""
    # small_transform takes the RGB numpy array directly (it divides by 255)
    # and already returns a batched (1, 3, h, w) tensor -- no extra unsqueeze.
    input_batch = midas_transforms(image_rgb).to("cpu")
    with torch.no_grad():
        prediction = midas(input_batch)  # (1, h, w)
        # Bicubic interpolation needs NCHW input, hence the channel axis.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        )
    # Index instead of squeeze() so a 1-pixel-tall/wide image keeps 2 dims.
    return prediction[0, 0].cpu().numpy()


def _species_summary():
    """Return a short Wikipedia blurb about trees, or a fallback message."""
    try:
        return wikipedia.summary("tree", sentences=2)
    except Exception:
        # Best effort: network failures, disambiguation, missing page, etc.
        return "Tree species information not available."


def estimate_tree_height(image):
    """Detect a target object, crop it, and produce a demo height estimate.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        ``(message, crop, summary)``: a height message, the detected region as
        a PIL image, and a Wikipedia summary string. When there is no image or
        no matching detection, ``crop`` and ``summary`` are ``None``.
    """
    if image is None:
        return "No image provided", None, None

    # PIL images are RGB(A)/L; normalise to a 3-channel RGB uint8 array.
    image_rgb = np.array(image.convert("RGB"))

    # Ultralytics interprets numpy inputs as BGR (OpenCV convention).
    results = yolo_model(cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR))
    boxes = results[0].boxes.xyxy.cpu().numpy()
    labels = results[0].boxes.cls.cpu().numpy()
    target_boxes = [
        box for box, label in zip(boxes, labels) if int(label) == TARGET_CLASS_ID
    ]
    if not target_boxes:
        return "No tree detected", None, None

    rows, cols = _box_to_slices(target_boxes[0], image_rgb.shape)
    tree_crop = image_rgb[rows, cols]

    depth = _relative_depth_map(image_rgb)
    avg_depth = float(np.mean(depth[rows, cols]))
    estimated_height_m = avg_depth * DEPTH_TO_METERS

    return (
        f"Estimated Tree Height: {estimated_height_m:.2f} meters",
        Image.fromarray(tree_crop),
        _species_summary(),
    )
# Gradio UI: one PIL image input and three outputs matching
# estimate_tree_height's (message, crop, summary) return tuple.
demo = gr.Interface(
    fn=estimate_tree_height,
    inputs=gr.Image(type="pil"),
    outputs=[
        gr.Textbox(label="Tree Height Estimate"),
        gr.Image(label="Detected Tree"),
        gr.Textbox(label="Tree Species Info"),
    ],
    title="🌳 Tree Measurement App",
    description="Capture a tree image to estimate its height and get basic species info.",
)

if __name__ == "__main__":
    # Start the server only when run as a script, so importing this module
    # (tests, `gradio app.py` reload mode) does not block on launch().
    demo.launch()