pranaya20's picture
Update app.py
0fa81a3 verified
raw
history blame
4.27 kB
import base64
import html
import logging
import os
from datetime import datetime

import cv2
import folium
import gradio as gr
import h3
import numpy as np
import pandas as pd
import torch
from folium.plugins import MarkerCluster
from PIL import Image
from ultralytics import YOLO
# Load the YOLO detector (yolov8n weights).
# NOTE(review): yolov8n.pt is COCO-pretrained and COCO has no "tree" class,
# so its detections cannot identify trees directly — confirm intended use.
model = YOLO("yolov8n.pt")

# Try loading the MiDaS depth-estimation model. torch.hub fetches it remotely,
# so failure is tolerated: height estimation is simply disabled in that case.
try:
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    transform = midas_transforms.small_transform
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    midas.to(device)
    # Switch to inference mode; otherwise BatchNorm/Dropout layers behave as
    # during training and the depth predictions are wrong.
    midas.eval()
    height_estimation_available = True
except Exception:
    # Narrowed from a bare `except:` (which also caught KeyboardInterrupt and
    # SystemExit) and logged so the reason for the fallback is visible.
    logging.getLogger(__name__).warning(
        "MiDaS depth model could not be loaded; height estimation disabled",
        exc_info=True,
    )
    height_estimation_available = False
# Placeholder species classifier (could be replaced by real model)
def classify_species(image):
    """Return a species label for *image*.

    Stub implementation: the image is ignored and a fixed label is returned
    until a real classification model is plugged in.
    """
    placeholder_label = "Neem Tree"
    return placeholder_label
# On first run, create the results CSV containing only its header row so that
# later analyses always have a file to add records to.
if not os.path.exists("tree_data.csv"):
    _header_only = pd.DataFrame(
        columns=[
            "timestamp",
            "latitude",
            "longitude",
            "h3_index",
            "species",
            "height_m",
            "image_path",
        ]
    )
    _header_only.to_csv("tree_data.csv", index=False)
# Tree analysis function
def _to_rgb_array(image):
    """Return *image* (a PIL image or an array-like) as a 3-channel RGB uint8 array."""
    # gr.Image delivers a numpy array unless configured with type="pil";
    # normalise both forms, and RGBA/greyscale input, to plain RGB.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    return np.array(image.convert("RGB"))


def _estimate_relative_depth(rgb):
    """Return the median MiDaS prediction for an RGB uint8 image, rounded to 2 dp.

    NOTE(review): MiDaS predicts *relative inverse depth* with no metric scale,
    so this value is not a height in metres even though the app labels it so.
    """
    # small_transform expects an RGB image and already returns a batched
    # (1, 3, H, W) tensor. The former RGB->BGR swap, the fixed 256x256 resize
    # (which distorted the aspect ratio) and the extra unsqueeze(0) (which
    # produced a 5-D input) are therefore not needed.
    input_batch = transform(rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the prediction back to the original image resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    return round(float(np.median(prediction.cpu().numpy())), 2)


def _latlng_to_h3(lat, lng, resolution=9):
    """Return the H3 cell index containing (lat, lng) at the given resolution."""
    # h3 v4 renamed geo_to_h3 to latlng_to_cell; support both major versions.
    to_cell = getattr(h3, "latlng_to_cell", None) or h3.geo_to_h3
    return to_cell(lat, lng, resolution)


def _append_record(record, csv_path="tree_data.csv"):
    """Append one record (a column -> value dict, in file column order) to *csv_path*."""
    # Appending avoids re-reading and rewriting the whole CSV on every call;
    # a header is written only if the file does not exist yet.
    write_header = not os.path.exists(csv_path)
    pd.DataFrame([record]).to_csv(csv_path, mode="a", header=write_header, index=False)


def analyze_tree(image, latitude, longitude):
    """Analyse one tree photo, persist the result, and return a text summary.

    Args:
        image: the captured photo, as a PIL image or an array-like (gr.Image's
            default numpy output).
        latitude: latitude in degrees, as a number or numeric string.
        longitude: longitude in degrees, as a number or numeric string.

    Returns:
        A multi-line summary (species, estimated height, H3 index), or a short
        message explaining why the input was rejected.

    Side effects:
        Saves the photo as a ``tree_<timestamp>.jpg`` file and appends a row
        to ``tree_data.csv``.
    """
    if image is None:
        return "No image captured. Please take a photo first."
    try:
        lat = float(latitude)
        lng = float(longitude)
    except (TypeError, ValueError):
        return "Latitude and longitude must be numbers (try 'Get Location')."
    # Also rejects NaN, since every comparison with NaN is False.
    if not (-90.0 <= lat <= 90.0 and -180.0 <= lng <= 180.0):
        return "Latitude must be in [-90, 90] and longitude in [-180, 180]."

    rgb = _to_rgb_array(image)

    # Object detection with YOLO.
    # NOTE(review): the detections are not consumed anywhere below (the old
    # unused boxes/classes locals were dropped); kept as a placeholder.
    model(rgb)

    height_m = "N/A"
    if height_estimation_available:
        height_m = _estimate_relative_depth(rgb)

    # Species classification (placeholder)
    species = classify_species(image)

    # H3 geolocation at resolution 9
    h3_index = _latlng_to_h3(lat, lng, 9)

    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    # Microseconds in the file name keep two captures within the same second
    # from overwriting each other's image.
    image_path = f"tree_{now.strftime('%Y%m%d_%H%M%S_%f')}.jpg"
    # Save from the RGB array: this works for numpy input (which has no
    # .save method) and for RGBA input (which JPEG cannot store).
    Image.fromarray(rgb).save(image_path)

    _append_record(
        {
            "timestamp": timestamp,
            "latitude": lat,
            "longitude": lng,
            "h3_index": h3_index,
            "species": species,
            "height_m": height_m,
            "image_path": image_path,
        }
    )
    return f"Species: {species}\nEstimated Height: {height_m} meters\nH3 Index: {h3_index}"
# Generate map with markers
def generate_map():
    """Build a clustered Folium map of every tree recorded in ``tree_data.csv``.

    The map is also written to ``tree_map.html``. Returns an ``<iframe>`` whose
    ``srcdoc`` holds the full map page. ``gr.HTML`` inserts markup without
    executing its ``<script>`` tags, and Leaflet needs those scripts to draw
    the map, so returning the raw page would leave an empty area.
    """
    try:
        df = pd.read_csv("tree_data.csv")
    except FileNotFoundError:
        df = pd.DataFrame(columns=["latitude", "longitude", "species", "height_m", "h3_index"])
    # Rows with blank or non-numeric coordinates cannot be placed on the map.
    df = df.assign(
        latitude=pd.to_numeric(df["latitude"], errors="coerce"),
        longitude=pd.to_numeric(df["longitude"], errors="coerce"),
    ).dropna(subset=["latitude", "longitude"])

    # Default view: 20.5937 N, 78.9629 E (roughly the centre of India).
    m = folium.Map(location=[20.5937, 78.9629], zoom_start=5)
    marker_cluster = MarkerCluster().add_to(m)
    for _, row in df.iterrows():
        # Popup content is HTML: use <br> for line breaks (a plain "\n" renders
        # as a space) and escape the stored values.
        popup_html = "<br>".join(
            html.escape(f"{label}: {row[column]}")
            for label, column in (
                ("Species", "species"),
                ("Height", "height_m"),
                ("H3", "h3_index"),
            )
        )
        folium.Marker(
            location=[row["latitude"], row["longitude"]],
            popup=popup_html,
        ).add_to(marker_cluster)

    map_path = "tree_map.html"
    m.save(map_path)
    # Explicit encoding so the result does not depend on the platform default.
    with open(map_path, encoding="utf-8") as f:
        page = f.read()
    return (
        '<iframe style="width:100%; height:600px; border:none;" '
        f'srcdoc="{html.escape(page, quote=True)}"></iframe>'
    )
# Gradio interface
def js_location():
    """Return the browser-side JavaScript run by the "Get Location" button.

    The script awaits ``navigator.geolocation.getCurrentPosition`` and resolves
    to ``[latitude, longitude]``, matching the button's two outputs.
    """
    geolocation_script = """
async () => {
const pos = await new Promise((resolve, reject) => {
navigator.geolocation.getCurrentPosition(resolve, reject);
});
return [pos.coords.latitude, pos.coords.longitude];
}
"""
    return geolocation_script
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # The `source=` keyword was removed in Gradio 4; this file already
            # uses the Gradio 4 `js=` listener argument, so use `sources=[...]`.
            # type="pil" hands analyze_tree the PIL image it saves with
            # .save() (the default type is numpy). streaming=True was dropped:
            # nothing here subscribes with .stream(), and a normal webcam lets
            # the user capture a still frame before pressing "Analyze Tree".
            cam = gr.Image(sources=["webcam"], type="pil")
            loc_btn = gr.Button("Get Location")
            lat = gr.Textbox(label="Latitude")
            lon = gr.Textbox(label="Longitude")
            analyze_btn = gr.Button("Analyze Tree")
            output = gr.Textbox(label="Results")
        with gr.Column():
            map_output = gr.HTML()
    # Client-side only: the returned JS fills the two coordinate textboxes.
    loc_btn.click(None, js=js_location(), outputs=[lat, lon])
    # Chain the map refresh with .then() so it runs after the new record has
    # been written; two independent .click() listeners run concurrently and
    # the map could be rendered before the latest tree is saved.
    analyze_btn.click(fn=analyze_tree, inputs=[cam, lat, lon], outputs=output).then(
        fn=generate_map, outputs=map_output
    )
    # Show previously recorded trees as soon as the page loads.
    demo.load(fn=generate_map, outputs=map_output)

if __name__ == "__main__":
    demo.launch()