# Spaces:
# Runtime error
# Runtime error
import base64
import html
import json
import os
import tempfile
from io import BytesIO

import cv2
import folium
import gradio as gr
import numpy as np
import requests
import torch
import transformers
import wikipedia
from geopy.geocoders import Nominatim
from PIL import Image
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
class TreeAnalyzer:
    """Load the models/services used to analyse a tree photo and run them.

    Construction loads (best effort) a MiDaS depth model from ``torch.hub``,
    a ``transformers`` image-classification pipeline and a Nominatim reverse
    geocoder.  A model that fails to load is left as ``None`` and the matching
    analysis method returns an explanatory message instead of raising.
    """

    # Lower-cased substrings that mark a classifier label as tree-related.
    TREE_KEYWORDS = (
        'tree', 'oak', 'pine', 'maple', 'birch', 'cedar',
        'fir', 'palm', 'willow', 'cherry', 'apple',
    )

    def __init__(self):
        self.setup_models()
        # Reverse geocoder used by get_location_info (network service).
        self.geolocator = Nominatim(user_agent="tree_analyzer")

    def setup_models(self):
        """Initialize all required models.

        Sets ``self.midas`` / ``self.midas_transforms`` / ``self.transform``
        for depth estimation (``self.midas`` is ``None`` on failure) and
        ``self.plant_classifier`` (``None`` if neither checkpoint loads).
        Errors are printed, never raised, so the app can still start.
        """
        print("Loading models...")

        # Small MiDaS variant plus its matching input transform.
        try:
            self.midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small')
            self.midas.eval()
            self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
            self.transform = self.midas_transforms.small_transform
            print("✅ MiDaS model loaded successfully")
        except Exception as e:
            print(f"Error loading MiDaS: {e}")
            self.midas = None

        # Image classifier returning its 3 best labels.  The pipeline option
        # is ``top_k``; the previous ``return_top_k`` is not a parameter of
        # ImageClassificationPipeline and made construction raise.
        # NOTE(review): both checkpoints appear to be general ImageNet
        # classifiers rather than plant-specific models.
        try:
            self.plant_classifier = pipeline(
                "image-classification",
                model="microsoft/resnet-50",
                top_k=3,
            )
            print("✅ Plant classifier loaded successfully")
        except Exception as e:
            print(f"Error loading plant classifier: {e}")
            try:
                self.plant_classifier = pipeline(
                    "image-classification",
                    model="google/vit-base-patch16-224",
                    top_k=3,
                )
                print("✅ Fallback classifier loaded successfully")
            except Exception as fallback_error:
                self.plant_classifier = None
                print(f"❌ Could not load plant classifier: {fallback_error}")

    def estimate_tree_height(self, image, known_object_height=1.7):
        """Estimate tree height from a single image via MiDaS relative depth.

        Args:
            image: PIL image of the tree (converted to RGB first); an RGB
                array accepted by ``np.array`` also works.
            known_object_height: intended height in metres of a reference
                object (e.g. a 1.7 m person).  NOTE: not used by the current
                heuristic; kept for interface compatibility.

        Returns:
            ``(message, depth_map_rgb)``: a human-readable estimate and a
            uint8 RGB colour visualisation of the depth map, or
            ``(error_message, None)`` if the model is missing or fails.
        """
        if self.midas is None:
            return "MiDaS model not available", None
        try:
            # MiDaS transforms expect an RGB uint8 array (not BGR); force RGB
            # so RGBA / greyscale uploads are handled too.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            img_rgb = np.array(image)

            input_batch = self.transform(img_rgb).to(torch.float32)
            with torch.no_grad():
                prediction = self.midas(input_batch)
                # Resize the model output back to the input resolution.
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=img_rgb.shape[:2],
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()
            depth_map = prediction.cpu().numpy()

            # Visualisation: stretch to 0-255, colour-map (OpenCV yields BGR),
            # then convert to RGB to match the input image's channel order.
            depth_u8 = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
            depth_map_colored = cv2.cvtColor(
                cv2.applyColorMap(depth_u8, cv2.COLORMAP_PLASMA), cv2.COLOR_BGR2RGB
            )

            estimated_height = self._height_from_depth(depth_map)
            return f"Estimated height: {estimated_height:.1f} meters", depth_map_colored
        except Exception as e:
            return f"Error in height estimation: {str(e)}", None

    @staticmethod
    def _height_from_depth(depth_map, scale=0.02, min_height=2.0, max_height=50.0):
        """Heuristic height in metres from a 2-D relative depth map.

        Assumes the tree occupies the central region (rows 10-80 %, columns
        30-70 %), takes that region's depth spread relative to the map's peak
        value, multiplies by the image height and ``scale``, and clamps to
        ``[min_height, max_height]``.  This is not a calibrated measurement.
        """
        height, width = depth_map.shape[:2]
        tree_region = depth_map[int(height * 0.1):int(height * 0.8), int(width * 0.3):int(width * 0.7)]
        peak = float(np.max(depth_map)) if depth_map.size else 0.0
        if tree_region.size == 0 or peak <= 0:
            # Degenerate input (tiny image / flat or non-positive map): the
            # ratio below would be empty or NaN, so report the lower bound.
            return min_height
        depth_range = float(np.max(tree_region) - np.min(tree_region))
        estimate = depth_range / peak * height * scale
        return max(min_height, min(max_height, estimate))

    def identify_tree_species(self, image):
        """Classify the image and attach Wikipedia information to the top hits.

        Predictions whose label contains a ``TREE_KEYWORDS`` entry are
        preferred; if none match, the top two raw predictions are used.

        Returns:
            ``(status_message, species_info)`` where ``species_info`` is a list
            of up to two dicts with keys ``species``, ``confidence``,
            ``wiki_title``, ``summary`` and ``url`` (``url`` may be ``None``).
        """
        if self.plant_classifier is None:
            return "Plant classifier not available", []
        try:
            predictions = self.plant_classifier(image)
            tree_predictions = [
                pred for pred in predictions
                if any(keyword in pred['label'].lower() for keyword in self.TREE_KEYWORDS)
            ]
            if not tree_predictions:
                tree_predictions = predictions[:2]  # no tree-specific match
            species_info = [self._describe_species(pred) for pred in tree_predictions[:2]]
            return "Species identified successfully", species_info
        except Exception as e:
            return f"Error in species identification: {str(e)}", []

    @staticmethod
    def _describe_species(pred):
        """Build the species-info dict for one prediction, enriched from Wikipedia.

        Always returns an entry; if the search finds nothing or any lookup
        fails, the Wikipedia fields keep their "not found" placeholders.
        """
        entry = {
            'species': pred['label'],
            'confidence': pred['score'],
            'wiki_title': 'Not found',
            'summary': 'No Wikipedia information available',
            'url': None,
        }
        # Labels may list comma-separated synonyms; search on the first one.
        query = pred['label'].split(',')[0].strip() + " tree"
        try:
            wiki_results = wikipedia.search(query, results=1)
            if wiki_results:
                title = wiki_results[0]
                # auto_suggest=False: fetch exactly the article the search
                # returned, so page and summary describe the same page.
                page = wikipedia.page(title, auto_suggest=False)
                summary = wikipedia.summary(title, sentences=2, auto_suggest=False)
                entry.update(wiki_title=page.title, summary=summary, url=page.url)
        except Exception as e:
            # Best effort: network errors, disambiguation or missing pages
            # leave the placeholder fields in place.
            print(f"Wikipedia lookup failed for {query!r}: {e}")
        return entry

    def get_location_info(self, latitude, longitude):
        """Reverse-geocode a coordinate pair and render a marker map.

        Returns:
            ``(message, map_path)`` where ``map_path`` is the path of a
            standalone HTML file written by folium (the caller owns it), or
            ``(message, None)`` when coordinates are missing or a step fails.
        """
        if latitude is None or longitude is None:
            return "Location not provided", None
        try:
            location = self.geolocator.reverse(f"{latitude}, {longitude}")

            m = folium.Map(location=[latitude, longitude], zoom_start=15)
            folium.Marker(
                [latitude, longitude],
                popup=f"Tree Location<br>Lat: {latitude:.6f}<br>Lon: {longitude:.6f}",
                tooltip="Tree Location",
            ).add_to(m)

            # Reserve a persistent temp path and close the handle immediately
            # (it was previously left open); folium then writes by name.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as map_file:
                map_path = map_file.name
            m.save(map_path)

            address = location.address if location else "Address not found"
            return f"Location: {address}", map_path
        except Exception as e:
            return f"Error getting location info: {str(e)}", None
def analyze_tree(image, latitude, longitude):
    """Main function to analyze a tree from an image and optional location.

    Args:
        image: uploaded PIL image, or ``None`` if nothing was uploaded.
        latitude: optional GPS latitude (``None`` when left blank).
        longitude: optional GPS longitude (``None`` when left blank).

    Returns:
        A 5-tuple matching the UI outputs: height message, species status,
        species Markdown, location message, and embeddable map HTML (or
        ``None`` when no map was produced).
    """
    if image is None:
        return "Please upload an image", "", "", "", None

    # Reuse one analyzer: constructing TreeAnalyzer reloads every model.
    analyzer = _get_analyzer()

    height_result, _depth_map = analyzer.estimate_tree_height(image)
    species_result, species_info = analyzer.identify_tree_species(image)
    location_result, map_file = analyzer.get_location_info(latitude, longitude)

    return (
        height_result,
        species_result,
        _format_species_markdown(species_info),
        location_result,
        # The output component renders HTML, not a file path.
        _map_file_to_html(map_file),
    )


# Lazily created analyzer shared by all requests.
# NOTE(review): no lock; two concurrent first requests could each build one.
_ANALYZER = None


def _get_analyzer():
    """Return the shared TreeAnalyzer, creating it (loading models) on first use."""
    global _ANALYZER
    if _ANALYZER is None:
        _ANALYZER = TreeAnalyzer()
    return _ANALYZER


def _format_species_markdown(species_info):
    """Render species-info dicts (from identify_tree_species) as Markdown."""
    parts = []
    for info in species_info or []:
        parts.append(f"**{info['species']}** (Confidence: {info['confidence']:.2f})\n")
        parts.append(f"*{info['summary']}*\n")
        if info['url']:
            parts.append(f"[Wikipedia Link]({info['url']})\n")
        parts.append("\n")
    return "".join(parts)


def _map_file_to_html(map_file):
    """Embed a saved folium HTML file in an iframe and delete the file.

    Returns ``None`` when there is no file or it cannot be read.
    """
    if not map_file:
        return None
    try:
        with open(map_file, encoding="utf-8") as fh:
            page = fh.read()
    except OSError:
        return None
    try:
        os.remove(map_file)  # best effort: the temp file is no longer needed
    except OSError:
        pass
    # srcdoc keeps the map's scripts in their own document.
    return (
        '<iframe srcdoc="' + html.escape(page, quote=True) + '" '
        'style="width:100%; height:450px; border:none;"></iframe>'
    )
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the tree-analysis app.

    Layout: an input column (image, optional latitude/longitude, analyse
    button) and a results column with "Results" and "Location Map" tabs.
    The button sends the three inputs to ``analyze_tree``, whose 5-tuple
    return maps positionally onto the five output components.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Tree Analysis App", theme=gr.themes.Soft()) as demo:
        # Markdown text is kept flush-left so it is not rendered as a code block.
        gr.Markdown("""
# 🌳 Tree Analysis App
Upload an image of a tree and optionally provide GPS coordinates to get:
- **Tree height estimation** using MiDaS depth estimation
- **Species identification** with Wikipedia information
- **Location mapping** of where the tree was captured
""")
        with gr.Row():
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Tree Image")
                with gr.Row():
                    # Example values live in `info`: `placeholder` is not
                    # accepted by gr.Number in some Gradio releases (e.g. 4.x).
                    latitude_input = gr.Number(
                        label="Latitude (optional)",
                        info="GPS latitude coordinate, e.g., 40.7128",
                    )
                    longitude_input = gr.Number(
                        label="Longitude (optional)",
                        info="GPS longitude coordinate, e.g., -74.0060",
                    )
                analyze_btn = gr.Button("🔍 Analyze Tree", variant="primary")
            with gr.Column(scale=2):
                with gr.Tab("Results"):
                    height_output = gr.Textbox(label="Height Estimation", lines=2)
                    species_status = gr.Textbox(label="Species Identification Status", lines=1)
                    species_output = gr.Markdown(label="Species Information")
                    location_output = gr.Textbox(label="Location Information", lines=2)
                with gr.Tab("Location Map"):
                    map_output = gr.HTML(label="Location Map")

        # Output order must match analyze_tree's return tuple.
        analyze_btn.click(
            fn=analyze_tree,
            inputs=[image_input, latitude_input, longitude_input],
            outputs=[height_output, species_status, species_output, location_output, map_output],
        )

        # Example section
        gr.Markdown("""
## 📱 Usage Tips:
1. **Take a clear photo** of the tree with good lighting
2. **Include reference objects** (like people) for better height estimation
3. **Enable GPS** on your phone and note the coordinates
4. **Upload the image** and enter GPS coordinates if available
5. **Click Analyze** to get comprehensive tree information
## 🔗 Sharing:
This app generates a shareable link that you can send to others!
""")
    return demo
# Main execution
if __name__ == "__main__":
    # Build the Blocks app once, then hand it to Gradio's web server.
    demo = create_interface()

    # share=True requests a public share link; "0.0.0.0" listens on every
    # network interface; show_error=True shows exceptions in the browser UI.
    launch_kwargs = dict(
        share=True,
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )
    demo.launch(**launch_kwargs)