# Hugging Face Spaces status: Runtime error
import base64
import html
import logging
import os
from datetime import datetime

import cv2
import folium
import gradio as gr
import h3
import numpy as np
import pandas as pd
import torch
from folium.plugins import MarkerCluster
from PIL import Image
from ultralytics import YOLO
# Object detector: Ultralytics YOLOv8 "nano" weights, loaded once when the
# module is imported.
model = YOLO(model="yolov8n.pt")
# Optional MiDaS depth model. If downloading or loading it fails, the app keeps
# running and analyze_tree reports the height as "N/A".
# `device` is defined outside the try so it exists even when loading fails.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
    midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
    transform = midas_transforms.small_transform
    midas.to(device)
    # Inference only: put the network in evaluation mode, as in the MiDaS
    # PyTorch Hub example.
    midas.eval()
    height_estimation_available = True
except Exception:  # hub/network/model errors; Ctrl-C and SystemExit still propagate
    logging.warning(
        "MiDaS depth model unavailable; height estimation disabled",
        exc_info=True,
    )
    height_estimation_available = False
def classify_species(image):
    """Return a tree species label for *image*.

    Placeholder implementation: the input is ignored and a fixed label is
    returned. Swap in a real classifier model here.
    """
    placeholder_label = "Neem Tree"
    return placeholder_label
# On first run, create the results CSV containing only the header row, so later
# reads and appends always find the expected columns.
if not os.path.exists("tree_data.csv"):
    header_only = pd.DataFrame(
        columns=[
            "timestamp",
            "latitude",
            "longitude",
            "h3_index",
            "species",
            "height_m",
            "image_path",
        ]
    )
    header_only.to_csv("tree_data.csv", index=False)
# Tree analysis function
def _median_relative_depth(image_rgb):
    """Return the median MiDaS prediction for an RGB image, rounded to 2 dp.

    ``image_rgb`` is an ``H x W x 3`` RGB array. MiDaS predicts *relative
    inverse depth*, so the result is a unitless score, not a height in metres.
    """
    # small_transform expects RGB input. It resizes and normalises the image
    # and already adds the batch dimension, returning a (1, 3, h, w) tensor.
    input_batch = transform(image_rgb).to(device)
    with torch.no_grad():
        prediction = midas(input_batch)
        # Upsample the prediction back to the photo's resolution.
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image_rgb.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    depth_map = prediction.cpu().numpy()
    return round(float(np.median(depth_map)), 2)


def _h3_cell(lat, lon, resolution=9):
    """Return the H3 cell index of ``(lat, lon)`` at ``resolution``.

    h3-py 4.x renamed ``geo_to_h3`` to ``latlng_to_cell``; both APIs are
    supported so the app works with either major version installed.
    """
    if hasattr(h3, "latlng_to_cell"):
        return h3.latlng_to_cell(lat, lon, resolution)
    return h3.geo_to_h3(lat, lon, resolution)


def analyze_tree(image, latitude, longitude):
    """Analyse one captured photo and append a record to ``tree_data.csv``.

    Args:
        image: The captured frame, either a ``PIL.Image.Image`` or an image
            array (Gradio's default ``gr.Image`` value). ``None`` when nothing
            has been captured yet.
        latitude: Latitude in decimal degrees, as a string or a number.
        longitude: Longitude in decimal degrees, as a string or a number.

    Returns:
        A human-readable summary for the results textbox. On invalid input,
        an explanatory message is returned and nothing is saved.
    """
    if image is None:
        return "No image captured. Take a photo with the webcam first."
    try:
        lat = float(latitude)
        lon = float(longitude)
    except (TypeError, ValueError):
        return "Latitude and longitude must be numbers. Click 'Get Location' or type them in."
    # The chained comparison is also False for NaN, so NaN is rejected here.
    if not (-90.0 <= lat <= 90.0 and -180.0 <= lon <= 180.0):
        return "Latitude must be in [-90, 90] and longitude in [-180, 180]."

    # Normalise to an RGB PIL image: arrays have no .save(), and JPEG cannot
    # store an alpha channel. The RGB array feeds the depth model.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(np.asarray(image))
    image = image.convert("RGB")
    image_np = np.asarray(image)

    # NOTE(review): the previous version ran YOLO here but never used its boxes
    # or classes. That detection pass is skipped until its output is consumed
    # (e.g. to crop the tree before classification).

    height_m = "N/A"
    if height_estimation_available:
        # TODO: this is a relative MiDaS depth score, not metres. The
        # "meters" label below is kept for output compatibility.
        height_m = _median_relative_depth(image_np)

    species = classify_species(image)
    h3_index = _h3_cell(lat, lon, 9)

    now = datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    # Microseconds in the filename stop two captures within the same second
    # from overwriting each other's photo.
    image_path = f"tree_{now.strftime('%Y%m%d_%H%M%S_%f')}.jpg"
    image.save(image_path)

    # Append a single row instead of re-reading and rewriting the whole CSV.
    csv_path = "tree_data.csv"
    record = pd.DataFrame(
        [[timestamp, lat, lon, h3_index, species, height_m, image_path]],
        columns=["timestamp", "latitude", "longitude", "h3_index", "species", "height_m", "image_path"],
    )
    record.to_csv(csv_path, mode="a", index=False, header=not os.path.exists(csv_path))

    return f"Species: {species}\nEstimated Height: {height_m} meters\nH3 Index: {h3_index}"
# Generate map with markers
def generate_map():
    """Render every saved tree record as a clustered marker map.

    Reads ``tree_data.csv``, saves the folium map to ``tree_map.html`` and
    returns HTML embedding that page for a ``gr.HTML`` component.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` is the full map page.
    """
    df = pd.read_csv("tree_data.csv")
    m = folium.Map(location=[20.5937, 78.9629], zoom_start=5)
    marker_cluster = MarkerCluster().add_to(m)

    # folium rejects NaN or non-numeric marker locations, so coerce the
    # coordinates and skip rows that have no usable position.
    df["latitude"] = pd.to_numeric(df["latitude"], errors="coerce")
    df["longitude"] = pd.to_numeric(df["longitude"], errors="coerce")
    df = df.dropna(subset=["latitude", "longitude"])

    for _, row in df.iterrows():
        # Popups are HTML: use <br> for line breaks and escape the stored values.
        popup_html = "<br>".join(
            [
                f"Species: {html.escape(str(row['species']))}",
                f"Height: {html.escape(str(row['height_m']))}",
                f"H3: {html.escape(str(row['h3_index']))}",
            ]
        )
        folium.Marker(
            location=[row["latitude"], row["longitude"]],
            popup=popup_html,
        ).add_to(marker_cluster)

    map_path = "tree_map.html"
    m.save(map_path)
    with open(map_path, "r", encoding="utf-8") as f:
        page = f.read()
    # gr.HTML inserts markup via innerHTML, and <script> tags inserted that way
    # never execute, so folium's map would not render. An iframe with srcdoc
    # loads the page as its own document, so its scripts run.
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        'style="width: 100%; height: 500px; border: none;"></iframe>'
    )
# Gradio interface
def js_location():
    """Return the browser-side JavaScript handler for the "Get Location" button.

    The snippet is an async arrow function. It asks the browser's Geolocation
    API for the current position and resolves to
    ``[latitude, longitude]``.
    """
    snippet = """
async () => {
const pos = await new Promise((resolve, reject) => {
navigator.geolocation.getCurrentPosition(resolve, reject);
});
return [pos.coords.latitude, pos.coords.longitude];
}
"""
    return snippet
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # Gradio 4 API: `sources` (a list) replaces the removed `source`
            # argument. This file already requires Gradio 4 because of `js=`
            # below, so the old form raised at startup. type="pil" hands
            # analyze_tree an image it can .save(). streaming=True is dropped:
            # nothing listens to .stream(), and the analysis needs one snapshot.
            cam = gr.Image(sources=["webcam"], type="pil")
            loc_btn = gr.Button("Get Location")
            lat = gr.Textbox(label="Latitude")
            lon = gr.Textbox(label="Longitude")
            analyze_btn = gr.Button("Analyze Tree")
            output = gr.Textbox(label="Results")
        with gr.Column():
            map_output = gr.HTML()

    # Runs entirely in the browser (no Python fn) and fills the two textboxes.
    loc_btn.click(None, inputs=None, outputs=[lat, lon], js=js_location())
    # Chain the map refresh after the analysis so it sees the newly appended
    # row. Two separate .click listeners on one button are not ordered.
    analyze_btn.click(fn=analyze_tree, inputs=[cam, lat, lon], outputs=output).then(
        fn=generate_map, outputs=map_output
    )
    # Show previously recorded trees as soon as the page loads.
    demo.load(fn=generate_map, outputs=map_output)

if __name__ == "__main__":
    demo.launch()