pranaya20's picture
Update app.py
aa4f4c1 verified
raw
history blame
11.8 kB
import gradio as gr
import torch
import cv2
import numpy as np
import requests
import json
from PIL import Image
import transformers
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
import wikipedia
import folium
from geopy.geocoders import Nominatim
import base64
from io import BytesIO
import tempfile
import os
class TreeAnalyzer:
    """Estimate the height, identify the species and map the location of a tree.

    Attributes:
        midas: MiDaS depth model loaded through ``torch.hub``, or ``None``
            when loading failed.
        midas_transforms: Transform collection published with MiDaS, or
            ``None`` when loading failed.
        transform: The ``small_transform`` taken from ``midas_transforms``.
        plant_classifier: ``transformers`` image-classification pipeline, or
            ``None`` when neither candidate model could be loaded.
        geolocator: ``Nominatim`` client used for reverse geocoding.
    """

    # Whole words that mark a classifier label as tree-related.
    _TREE_KEYWORDS = frozenset({
        'tree', 'oak', 'pine', 'maple', 'birch', 'cedar', 'fir', 'palm',
        'willow', 'cherry', 'apple',
    })

    def __init__(self):
        """Load the models and create the geocoder client."""
        self.setup_models()
        self.geolocator = Nominatim(user_agent="tree_analyzer")

    def setup_models(self):
        """Initialize all required models.

        Loading is best-effort: a failure is printed and the corresponding
        attribute is left as ``None`` so the other analyses keep working.
        """
        print("Loading models...")
        # Define every attribute up front so a failed load can never leave the
        # instance without it.
        self.midas = None
        self.midas_transforms = None
        self.transform = None
        self.plant_classifier = None

        # Load MiDaS model for depth estimation
        try:
            midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small')
            midas.eval()
            self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
            self.transform = self.midas_transforms.small_transform
            # Publish the model only once its transform is available too.
            self.midas = midas
            print("✓ MiDaS model loaded successfully")
        except Exception as e:
            print(f"Error loading MiDaS: {e}")
            self.midas = None

        # Load plant classification model. The pipeline parameter that limits
        # the number of returned labels is `top_k`.
        try:
            self.plant_classifier = pipeline(
                "image-classification",
                model="microsoft/resnet-50",
                top_k=3,
            )
            print("✓ Plant classifier loaded successfully")
        except Exception as e:
            print(f"Error loading plant classifier: {e}")
            # Fall back to a second general-purpose image classifier
            try:
                self.plant_classifier = pipeline(
                    "image-classification",
                    model="google/vit-base-patch16-224",
                    top_k=3,
                )
                print("✓ Fallback classifier loaded successfully")
            except Exception as fallback_error:
                self.plant_classifier = None
                print(f"✗ Could not load plant classifier: {fallback_error}")

    def estimate_tree_height(self, image, known_object_height=1.7):
        """Estimate tree height from a MiDaS depth map.

        This is a rough heuristic, not a calibrated measurement: the spread of
        depth values in the central region of the image is divided by the
        maximum depth, scaled by the image height in pixels and a fixed
        factor, and clamped to 2-50 meters.

        Args:
            image: Image to analyze (a PIL image or an RGB array).
            known_object_height: Assumed height of a reference object
                (person = 1.7m). Currently not used by the calculation; kept
                for interface compatibility.

        Returns:
            A ``(message, depth_map_colored)`` tuple. ``depth_map_colored`` is
            the color-mapped depth image in OpenCV's BGR channel order, or
            ``None`` when the model is unavailable or estimation failed.
        """
        if self.midas is None:
            return "MiDaS model not available", None
        try:
            # The MiDaS transforms expect RGB input, so keep the channel order
            # and only drop alpha/palette modes when given a PIL-style image.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            img_rgb = np.array(image)

            # Prepare image for MiDaS
            input_batch = self.transform(img_rgb).to(torch.float32)

            # Generate depth map at the original image resolution
            with torch.no_grad():
                prediction = self.midas(input_batch)
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=img_rgb.shape[:2],
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()
            depth_map = prediction.cpu().numpy()

            # Normalize depth map for visualization
            depth_map_normalized = cv2.normalize(
                depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U
            )
            depth_map_colored = cv2.applyColorMap(depth_map_normalized, cv2.COLORMAP_PLASMA)

            # Simple height estimation (this is a simplified approach).
            # In reality, you'd need camera calibration and more sophisticated methods.
            height, width = depth_map.shape
            # Assume tree is in the center-upper portion of the image
            tree_region = depth_map[
                int(height * 0.1):int(height * 0.8),
                int(width * 0.3):int(width * 0.7),
            ]
            if tree_region.size == 0:
                # Tiny images can yield an empty crop; use the full map instead.
                tree_region = depth_map

            depth_range = float(np.max(tree_region) - np.min(tree_region))
            max_depth = float(np.max(depth_map))
            if max_depth > 0:
                # Rough estimation: scale based on depth range and image dimensions
                estimated_height = (depth_range / max_depth) * height * 0.02
            else:
                # A non-positive maximum would divide by zero or flip the sign.
                estimated_height = 0.0
            estimated_height = max(2.0, min(50.0, estimated_height))  # Clamp between 2-50 meters
            return f"Estimated height: {estimated_height:.1f} meters", depth_map_colored
        except Exception as e:
            return f"Error in height estimation: {str(e)}", None

    @staticmethod
    def _label_words(label):
        """Split a classifier label into a set of lowercase alphabetic words."""
        normalized = ''.join(ch if ch.isalpha() else ' ' for ch in label.lower())
        return set(normalized.split())

    def _describe_species(self, pred):
        """Build the info entry for one prediction, enriched from Wikipedia if possible."""
        entry = {
            'species': pred['label'],
            'confidence': pred['score'],
            'wiki_title': 'Not found',
            'summary': 'No Wikipedia information available',
            'url': None,
        }
        try:
            wiki_results = wikipedia.search(pred['label'] + " tree", results=1)
            if wiki_results:
                page = wikipedia.page(wiki_results[0])
                summary = wikipedia.summary(wiki_results[0], sentences=2)
                entry.update(wiki_title=page.title, summary=summary, url=page.url)
        except Exception as e:
            # Best-effort enrichment: keep the placeholder entry on any failure.
            print(f"Wikipedia lookup failed for {pred['label']!r}: {e}")
        return entry

    def identify_tree_species(self, image):
        """Identify tree species using image classification.

        Predictions whose label contains a tree keyword as a whole word are
        preferred; if none match, the top two predictions are used. At most
        two predictions are described, and each one always yields an entry.

        Returns:
            A ``(message, species_info)`` tuple where ``species_info`` is a
            list of dicts with the keys ``species``, ``confidence``,
            ``wiki_title``, ``summary`` and ``url``.
        """
        if self.plant_classifier is None:
            return "Plant classifier not available", []
        try:
            predictions = self.plant_classifier(image)
            # Match whole words so that e.g. 'fir' does not select
            # "fire engine" and 'pine' does not select "pineapple".
            tree_predictions = [
                pred for pred in predictions
                if self._label_words(pred['label']) & self._TREE_KEYWORDS
            ]
            if not tree_predictions:
                tree_predictions = predictions[:2]  # Take top 2 if no tree-specific matches
            species_info = [self._describe_species(pred) for pred in tree_predictions[:2]]
            return "Species identified successfully", species_info
        except Exception as e:
            return f"Error in species identification: {str(e)}", []

    def get_location_info(self, latitude, longitude):
        """Reverse-geocode the coordinates and save a map marking them.

        Returns:
            A ``(message, map_path)`` tuple. ``map_path`` is the path of a
            temporary HTML file holding the folium map (the caller owns and
            should remove it), or ``None`` when coordinates are missing or an
            error occurred.
        """
        if latitude is None or longitude is None:
            return "Location not provided", None
        try:
            # A coordinate tuple avoids any dependence on float-to-string
            # formatting (e.g. scientific notation for tiny values).
            location = self.geolocator.reverse((latitude, longitude))

            # Create a map
            m = folium.Map(location=[latitude, longitude], zoom_start=15)
            folium.Marker(
                [latitude, longitude],
                popup=f"Tree Location<br>Lat: {latitude:.6f}<br>Lon: {longitude:.6f}",
                tooltip="Tree Location"
            ).add_to(m)

            # Reserve a unique path and close the handle before folium writes
            # to it, so no file descriptor is leaked.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.html') as map_file:
                map_path = map_file.name
            m.save(map_path)

            address = location.address if location else "Address not found"
            return f"Location: {address}", map_path
        except Exception as e:
            return f"Error getting location info: {str(e)}", None
# Shared analyzer, created on first use so the models are loaded once per
# process instead of on every request.
_analyzer_instance = None


def _get_analyzer():
    """Return the shared TreeAnalyzer, constructing it on the first call."""
    global _analyzer_instance
    if _analyzer_instance is None:
        _analyzer_instance = TreeAnalyzer()
    return _analyzer_instance


def _render_map_html(map_file):
    """Turn a saved map page into embeddable iframe markup and remove the file.

    Returns ``None`` when there is no map file or it cannot be read.
    """
    import html

    if not map_file:
        return None
    try:
        with open(map_file, encoding="utf-8") as handle:
            page = handle.read()
    except OSError:
        return None
    finally:
        # The temporary file is only a hand-off; do not let it accumulate.
        try:
            os.remove(map_file)
        except OSError:
            pass
    # Embed the full page through `srcdoc` so it renders in its own document.
    return (
        f'<iframe srcdoc="{html.escape(page, quote=True)}" '
        'style="width: 100%; height: 500px; border: none;"></iframe>'
    )


def analyze_tree(image, latitude, longitude):
    """Main function to analyze tree from image and location.

    Args:
        image: Uploaded image, or ``None`` when nothing was uploaded.
        latitude: Latitude of the tree, or ``None``.
        longitude: Longitude of the tree, or ``None``.

    Returns:
        A 5-tuple ``(height_result, species_result, species_text,
        location_result, map_html)``. ``species_text`` is Markdown and
        ``map_html`` is iframe markup for the location map, or ``None`` when
        no map is available.
    """
    if image is None:
        return "Please upload an image", "", "", "", None

    analyzer = _get_analyzer()

    # The colored depth map is not shown in the interface, so it is discarded.
    height_result, _depth_map = analyzer.estimate_tree_height(image)
    species_result, species_info = analyzer.identify_tree_species(image)
    location_result, map_file = analyzer.get_location_info(latitude, longitude)

    # Format species information as Markdown
    parts = []
    for info in species_info or []:
        parts.append(f"**{info['species']}** (Confidence: {info['confidence']:.2f})\n")
        parts.append(f"*{info['summary']}*\n")
        if info['url']:
            parts.append(f"[Wikipedia Link]({info['url']})\n")
        parts.append("\n")
    species_text = "".join(parts)

    return (
        height_result,
        species_result,
        species_text,
        location_result,
        # The map output is an HTML component: give it markup, not a file path.
        _render_map_html(map_file),
    )
# Create Gradio interface
def create_interface():
    """Build the Gradio Blocks UI for the tree analysis app.

    The layout has an input column (image upload, optional latitude and
    longitude, analyze button) and an output column with a "Results" tab and
    a "Location Map" tab. The button is wired to ``analyze_tree``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(title="Tree Analysis App", theme=gr.themes.Soft()) as demo:
        gr.Markdown("""
        # 🌳 Tree Analysis App
        Upload an image of a tree and optionally provide GPS coordinates to get:
        - **Tree height estimation** using MiDaS depth estimation
        - **Species identification** with Wikipedia information
        - **Location mapping** of where the tree was captured
        """)
        with gr.Row():
            # Left column: inputs
            with gr.Column(scale=1):
                image_input = gr.Image(type="pil", label="Upload Tree Image")
                with gr.Row():
                    latitude_input = gr.Number(
                        label="Latitude (optional)",
                        placeholder="e.g., 40.7128",
                        info="GPS latitude coordinate"
                    )
                    longitude_input = gr.Number(
                        label="Longitude (optional)",
                        placeholder="e.g., -74.0060",
                        info="GPS longitude coordinate"
                    )
                analyze_btn = gr.Button("🔍 Analyze Tree", variant="primary")
            # Right column: outputs, split across two tabs
            with gr.Column(scale=2):
                with gr.Tab("Results"):
                    height_output = gr.Textbox(label="Height Estimation", lines=2)
                    species_status = gr.Textbox(label="Species Identification Status", lines=1)
                    species_output = gr.Markdown(label="Species Information")
                    location_output = gr.Textbox(label="Location Information", lines=2)
                with gr.Tab("Location Map"):
                    map_output = gr.HTML(label="Location Map")
        # Connect the analyze button; the output order must match the tuple
        # returned by analyze_tree.
        analyze_btn.click(
            fn=analyze_tree,
            inputs=[image_input, latitude_input, longitude_input],
            outputs=[height_output, species_status, species_output, location_output, map_output]
        )
        # Usage notes shown below the main layout
        gr.Markdown("""
        ## 📱 Usage Tips:
        1. **Take a clear photo** of the tree with good lighting
        2. **Include reference objects** (like people) for better height estimation
        3. **Enable GPS** on your phone and note the coordinates
        4. **Upload the image** and enter GPS coordinates if available
        5. **Click Analyze** to get comprehensive tree information
        ## 🔗 Sharing:
        This app generates a shareable link that you can send to others!
        """)
    return demo
# Main execution
if __name__ == "__main__":
    # Options for serving the app: bind on all interfaces at port 7860,
    # request a public share link and surface errors in the UI.
    launch_options = {
        "share": True,
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    # Build the interface, then start serving it.
    demo = create_interface()
    demo.launch(**launch_options)