Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
import cv2 | |
import numpy as np | |
from PIL import Image | |
from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification | |
import wikipedia | |
import folium | |
import tempfile | |
import os | |
import logging | |
import warnings | |
# NOTE(review): a catch-all "ignore" filter also hides deprecation warnings
# from gradio/transformers that would point at API misuse; consider narrowing.
warnings.simplefilter("ignore")

# Module logger, plus root logging configured at INFO level.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class TreeAnalyzer:
    """Run the image models behind the Tree Analyzer app.

    Wraps an optional MiDaS depth model (used for a rough height estimate)
    and an optional Hugging Face image-classification pipeline (used for
    species suggestions). A model that fails to load is left as ``None`` and
    the corresponding analysis method returns an explanatory message instead
    of raising.
    """

    # Classifier checkpoints tried in order; the first one that loads is used.
    # NOTE(review): these appear to be general-purpose image classifiers, not
    # plant-specific models -- confirm their label sets before trusting results.
    CLASSIFIER_CANDIDATES = (
        "google/vit-base-patch16-224",
        "microsoft/resnet-50",
        "facebook/convnext-tiny-224",
    )
    # Number of predictions requested from the classifier per image.
    CLASSIFIER_TOP_K = 10
    # Images taller or wider than this (pixels) are downscaled, keeping the
    # aspect ratio, before depth estimation to save memory.
    DEPTH_MAX_SIDE = 384
    # Number of ranked candidates returned by identify_tree_species.
    SPECIES_RESULT_COUNT = 3
    # Whole words that mark a classifier label as plant/tree related.
    PLANT_KEYWORDS = frozenset({
        # Trees
        'tree', 'oak', 'pine', 'maple', 'birch', 'cedar', 'fir', 'palm', 'willow',
        'cherry', 'apple', 'spruce', 'poplar', 'ash', 'elm', 'beech', 'sycamore',
        'acacia', 'eucalyptus', 'magnolia', 'chestnut', 'walnut', 'hickory',
        'cypress', 'juniper', 'redwood', 'bamboo', 'mahogany', 'teak',
        # Plants and botanical terms
        'plant', 'leaf', 'leaves', 'branch', 'bark', 'forest', 'wood', 'botanical',
        'flora', 'foliage', 'evergreen', 'deciduous', 'conifer', 'hardwood',
        'softwood', 'timber', 'shrub', 'bush', 'vine', 'fern', 'moss',
        # Specific species (genus) indicators
        'quercus', 'pinus', 'acer', 'betula', 'fagus', 'tilia', 'fraxinus',
        'platanus', 'castanea', 'juglans', 'carya', 'ulmus', 'salix',
    })

    def __init__(self):
        """Load all models immediately (this may download weights)."""
        self.setup_models()

    def setup_models(self):
        """Load the MiDaS depth model and the first classifier that loads.

        Sets ``self.midas``/``self.transform`` (or leaves them ``None``) and
        ``self.plant_classifier`` (or ``None``). Load failures are logged,
        never raised.
        """
        logger.info("Loading models for HF Spaces...")

        self.midas = None
        self.midas_transforms = None
        self.transform = None
        try:
            midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small', trust_repo=True)
            midas.eval()
            self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms', trust_repo=True)
            self.transform = self.midas_transforms.small_transform
            # Publish the model only once its transform is available too, so
            # estimate_tree_height never sees a model without a transform.
            self.midas = midas
            logger.info("✓ MiDaS loaded")
        except Exception as e:
            logger.error(f"MiDaS failed: {e}")

        self.plant_classifier = None
        for model_name in self.CLASSIFIER_CANDIDATES:
            try:
                # `top_k` is the pipeline's parameter for the number of labels
                # returned; an unknown kwarg here makes construction fail.
                self.plant_classifier = pipeline(
                    "image-classification",
                    model=model_name,
                    top_k=self.CLASSIFIER_TOP_K,
                )
            except Exception as e:
                logger.warning(f"Failed to load {model_name}: {e}")
                continue
            logger.info(f"✓ Loaded classifier: {model_name}")
            break
        else:
            logger.error("No image classifier could be loaded")

    def estimate_tree_height(self, image):
        """Return a human-readable height estimate for the tree in ``image``.

        Args:
            image: PIL image (or an H x W x 3 RGB array) of the tree.

        Returns:
            A message string with the estimate in metres, or an explanation of
            why no estimate was produced. Errors are logged, not raised.

        The figure is a heuristic, not a measurement: MiDaS predicts relative
        (unscaled) depth, and the result is clamped to 1.5-50 m.
        """
        if self.midas is None:
            return "Height estimation not available (MiDaS model failed to load)"
        try:
            # MiDaS transforms expect RGB input (the upstream usage example
            # converts cv2's BGR to RGB first), so keep the image in RGB order.
            if hasattr(image, "convert"):
                image = image.convert("RGB")
            img_rgb = np.array(image)
            h, w = img_rgb.shape[:2]

            # Downscale large inputs for memory efficiency.
            if h > self.DEPTH_MAX_SIDE or w > self.DEPTH_MAX_SIDE:
                scale = min(self.DEPTH_MAX_SIDE / h, self.DEPTH_MAX_SIDE / w)
                new_h, new_w = int(h * scale), int(w * scale)
                img_rgb = cv2.resize(img_rgb, (new_w, new_h))  # cv2 takes (width, height)

            input_batch = self.transform(img_rgb)
            with torch.no_grad():
                prediction = self.midas(input_batch)
                # Upsample the depth map back to the (resized) input resolution.
                prediction = torch.nn.functional.interpolate(
                    prediction.unsqueeze(1),
                    size=(img_rgb.shape[0], img_rgb.shape[1]),
                    mode="bicubic",
                    align_corners=False,
                ).squeeze()
            depth_map = prediction.cpu().numpy()

            # Middle half of the frame in each dimension (assumed to hold the tree).
            h_img, w_img = depth_map.shape
            center_region = depth_map[h_img // 4:3 * h_img // 4, w_img // 4:3 * w_img // 4]
            if center_region.size == 0:
                return "Could not estimate height from this image"

            depth_range = np.max(center_region) - np.min(center_region)
            # NOTE(review): center_region always spans ~half the rows, so this
            # ratio is ~0.5 for every image; the factor 30 has no stated derivation.
            height_ratio = center_region.shape[0] / h_img
            estimated_height = max(1.5, min(50.0, (depth_range * height_ratio * 30)))
            return f"Estimated height: {estimated_height:.1f} meters\n(Approximate estimate based on image depth analysis)"
        except Exception as e:
            logger.error(f"Height estimation error: {e}")
            return f"Height estimation failed: {str(e)}"

    @classmethod
    def _plant_keyword_hits(cls, label):
        """Return the PLANT_KEYWORDS that occur as whole words in ``label``.

        Simple plurals ("trees", "birches") also count. Whole-word matching
        avoids substring false positives such as "ash" in "crash" or "acer"
        in "racer".
        """
        words = "".join(ch if ch.isalpha() else " " for ch in label.lower()).split()
        hits = set()
        for word in words:
            forms = {word}
            if word.endswith("s"):
                forms.add(word[:-1])
            if word.endswith("es"):
                forms.add(word[:-2])
            hits |= forms & cls.PLANT_KEYWORDS
        return hits

    def identify_tree_species(self, image):
        """Classify ``image`` and rank the labels by plant relevance.

        Args:
            image: PIL image of the tree.

        Returns:
            ``(status_message, results)`` where ``results`` holds up to
            SPECIES_RESULT_COUNT dicts with keys ``species``, ``confidence``,
            ``plant_score`` (number of plant keywords in the label),
            ``is_plant_related`` and ``wiki_info``. On failure ``results`` is
            ``[]`` and the message explains why; errors are logged, not raised.
        """
        if self.plant_classifier is None:
            return "Species identification not available (classifier failed to load)", []
        try:
            if image.size[0] > 224 or image.size[1] > 224:
                image = image.resize((224, 224), Image.Resampling.LANCZOS)

            predictions = self.plant_classifier(image)

            species_candidates = []
            for pred in predictions:
                plant_score = len(self._plant_keyword_hits(pred['label']))
                species_candidates.append({
                    'species': pred['label'],
                    'confidence': pred['score'],
                    'plant_score': plant_score,
                    'is_plant_related': plant_score > 0,
                })

            # Plant relevance first, classifier confidence as the tie-breaker.
            species_candidates.sort(key=lambda c: (c['plant_score'], c['confidence']), reverse=True)
            final_results = species_candidates[:self.SPECIES_RESULT_COUNT]

            # Wikipedia lookups are slow network calls: only do them for the
            # candidates that are actually returned.
            for result in final_results:
                result['wiki_info'] = self.get_wikipedia_info(result['species'])

            if any(result['is_plant_related'] for result in final_results):
                return "Species identification completed", final_results
            return "Possible species identified (may not be plants)", final_results
        except Exception as e:
            logger.error(f"Species identification error: {e}")
            return f"Species identification failed: {str(e)}", []

    def get_wikipedia_info(self, species_name):
        """Look up a short Wikipedia summary for a classifier label.

        The label is trimmed at the first comma or parenthesis, then searched
        as-is and with " tree", " plant" and " species" suffixes; the first
        page whose summary can be fetched wins.

        Returns:
            Dict with ``title``, ``summary`` and ``url`` (``None`` when no
            page was found or an error occurred). Never raises.
        """
        try:
            clean_name = species_name.split(',')[0].split('(')[0].strip()
            if clean_name:
                search_queries = [
                    clean_name,
                    f"{clean_name} tree",
                    f"{clean_name} plant",
                    f"{clean_name} species",
                ]
                for query in search_queries:
                    try:
                        results = wikipedia.search(query, results=2)
                    except Exception:
                        # Best effort: a failed search just moves on to the next query.
                        continue
                    for result in results:
                        try:
                            page = wikipedia.page(result, auto_suggest=False)
                            summary = wikipedia.summary(result, sentences=2, auto_suggest=False)
                            return {
                                'title': page.title,
                                'summary': summary,
                                'url': page.url,
                            }
                        except Exception:
                            # e.g. a disambiguation or missing page: try the next hit.
                            continue
            return {
                'title': 'No information found',
                'summary': f'Wikipedia information not available for {species_name}',
                'url': None,
            }
        except Exception as e:
            return {
                'title': 'Error',
                'summary': f'Could not retrieve information: {str(e)}',
                'url': None,
            }
# Process-wide analyzer, created lazily so the models are loaded once instead
# of on every button click.
_ANALYZER = None


def _get_analyzer():
    """Return the shared TreeAnalyzer, (re)creating it if no model is loaded."""
    global _ANALYZER
    if _ANALYZER is None or (_ANALYZER.midas is None and _ANALYZER.plant_classifier is None):
        _ANALYZER = TreeAnalyzer()
    return _ANALYZER


def _format_species_markdown(species_info):
    """Render the candidate list from identify_tree_species as Markdown."""
    if not species_info:
        return "No species information could be determined from this image."
    sections = []
    for i, info in enumerate(species_info, 1):
        wiki = info['wiki_info']
        lines = [
            f"## {i}. {info['species']}",
            f"**Confidence:** {info['confidence']:.3f}",
            f"**Plant-related:** {'Yes' if info['is_plant_related'] else 'Uncertain'}",
            f"**Wikipedia:** {wiki['title']}",
            f"{wiki['summary']}",
        ]
        if wiki['url']:
            lines.append(f"🔗 [Read more]({wiki['url']})")
        # Blank lines keep every field in its own Markdown paragraph; single
        # newlines would be merged into one line when rendered.
        sections.append("\n\n".join(lines) + "\n\n---\n")
    return "\n".join(sections)


def _describe_location(latitude, longitude):
    """Return (location text, map HTML) for the optional GPS coordinates."""
    if latitude is None or longitude is None:
        return "No GPS coordinates provided", "<p>No location data available</p>"
    try:
        # The range check is also False for NaN, so NaN is rejected here too.
        if not (-90 <= latitude <= 90 and -180 <= longitude <= 180):
            return (
                f"Invalid coordinates: {latitude}, {longitude} "
                "(latitude must be between -90 and 90, longitude between -180 and 180)",
                "<p>Could not generate map</p>",
            )
        location_result = f"Coordinates: {latitude:.6f}, {longitude:.6f}"
        m = folium.Map(location=[latitude, longitude], zoom_start=15)
        folium.Marker(
            [latitude, longitude],
            popup=f"Tree Location<br>{latitude:.6f}, {longitude:.6f}",
            tooltip="Tree Location",
        ).add_to(m)
        # gr.HTML inserts markup without running <script> tags, so use folium's
        # notebook representation, which embeds the whole map in an iframe.
        map_html = m._repr_html_()
    except Exception as e:
        return f"Error processing location: {str(e)}", "<p>Could not generate map</p>"
    return location_result, map_html


def analyze_tree(image, latitude, longitude):
    """Gradio click handler: analyze ``image`` and the optional coordinates.

    Args:
        image: uploaded PIL image, or ``None`` when nothing was uploaded.
        latitude: optional latitude in decimal degrees (``None`` if blank).
        longitude: optional longitude in decimal degrees (``None`` if blank).

    Returns:
        A 5-tuple of strings for the UI outputs, in order: status, height
        text, species Markdown, location text, map HTML. Failures are
        reported through the status string rather than raised.
    """
    if image is None:
        return "Please upload an image", "", "", "", ""
    try:
        analyzer = _get_analyzer()
        height_result = analyzer.estimate_tree_height(image)
        species_status, species_info = analyzer.identify_tree_species(image)
        species_text = _format_species_markdown(species_info)
        location_result, map_html = _describe_location(latitude, longitude)
        return species_status, height_result, species_text, location_result, map_html
    except Exception as e:
        # logger.exception keeps the traceback in the server log.
        logger.exception(f"Analysis failed: {e}")
        return f"Analysis failed: {str(e)}", "", "", "", ""
# Gradio interface | |
def create_interface():
    """Build the Gradio Blocks UI for the Tree Analyzer.

    Layout: an input column (image upload, optional latitude/longitude, the
    analyze button and GPS help text) and an output column with two tabs
    (analysis results and location). The button is wired to ``analyze_tree``;
    its five return values map onto the five output components in order.

    Returns:
        The ``gr.Blocks`` app; the caller is responsible for launching it.
    """
    with gr.Blocks(title="Tree Analyzer", theme=gr.themes.Soft()) as demo:
        gr.HTML("""
        <div style="text-align: center; padding: 20px;">
            <h1>🌳 Tree Analyzer</h1>
            <p>Upload an image of a tree and optionally provide GPS coordinates for comprehensive analysis</p>
        </div>
        """)

        with gr.Row():
            # Inputs
            with gr.Column(scale=1):
                image_input = gr.Image(
                    type="pil",
                    label="Upload Tree Image",
                    height=400
                )
                with gr.Row():
                    lat_input = gr.Number(
                        label="Latitude",
                        placeholder="e.g., 40.7128",
                        precision=6
                    )
                    lon_input = gr.Number(
                        label="Longitude",
                        placeholder="e.g., -74.0060",
                        precision=6
                    )
                analyze_btn = gr.Button("🔍 Analyze Tree", variant="primary", size="lg")
                gr.HTML("""
                <div style="margin-top: 20px; padding: 15px; background-color: #f0f0f0; border-radius: 8px;">
                    <h4>📍 How to get GPS coordinates:</h4>
                    <ul>
                        <li>Google Maps: Right-click location → Copy coordinates</li>
                        <li>Phone: Use GPS coordinate apps</li>
                        <li>Camera: Check photo metadata for GPS info</li>
                    </ul>
                </div>
                """)

            # Outputs
            with gr.Column(scale=2):
                with gr.Tab("📊 Analysis Results"):
                    status_output = gr.Textbox(
                        label="Analysis Status",
                        interactive=False
                    )
                    height_output = gr.Textbox(
                        label="Height Estimation",
                        interactive=False,
                        lines=3
                    )
                    species_output = gr.Markdown(
                        label="Species Identification",
                        height=300
                    )
                with gr.Tab("🗺️ Location"):
                    location_output = gr.Textbox(
                        label="Location Information",
                        interactive=False
                    )
                    map_output = gr.HTML(
                        label="Location Map",
                        height=400
                    )

        # Output order must match the tuple returned by analyze_tree.
        analyze_btn.click(
            fn=analyze_tree,
            inputs=[image_input, lat_input, lon_input],
            outputs=[status_output, height_output, species_output, location_output, map_output]
        )

        gr.HTML("""
        <div style="text-align: center; padding: 20px; margin-top: 30px; border-top: 1px solid #ddd;">
            <p><strong>Features:</strong></p>
            <p>🔍 Species identification using AI models | 📏 Height estimation via depth analysis | 🗺️ Location mapping | 📚 Wikipedia integration</p>
            <p style="color: #666; font-size: 0.9em;">
                Note: This tool provides estimates and suggestions. For scientific purposes, consult with professional botanists or arborists.
            </p>
        </div>
        """)

    return demo
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it with a share link,
    # debug mode, and errors surfaced in the browser.
    demo = create_interface()
    launch_options = {"share": True, "debug": True, "show_error": True}
    demo.launch(**launch_options)