Update app.py
app.py
CHANGED
@@ -1,102 +1,297 @@
-# arboreal_gradio_app/app.py
 import gradio as gr
-import cv2
 import torch
 import numpy as np
-import
 import wikipedia
-import
 import os
-from ultralytics import YOLO
-from PIL import Image
-from datetime import datetime
-
-# Load YOLOv8 model
-model = YOLO("yolov8n.pt")  # You can replace with a custom-trained tree model
-
-# Constants
-KNOWN_PERSON_HEIGHT = 1.7  # meters
-CSV_FILE = "tree_log.csv"
-
-# Initialize CSV if it doesn't exist
-if not os.path.exists(CSV_FILE):
-    pd.DataFrame(columns=["Image", "Height (m)", "Species", "Latitude", "Longitude", "Timestamp"]).to_csv(CSV_FILE, index=False)
-
-# Function to estimate height using YOLO detections
-def estimate_tree_height(img):
-    results = model(img)
-    person_box, tree_box = None, None
-
-    for result in results:
-        for box in result.boxes:
-            cls = result.names[int(box.cls[0])]
-            xyxy = box.xyxy[0].cpu().numpy()
-            if cls.lower() == "person" and person_box is None:
-                person_box = xyxy
-            elif cls.lower() == "tree" and tree_box is None:
-                tree_box = xyxy
-
-    if tree_box is None or person_box is None:
-        return None, "Could not detect both tree and person."
-
-    # Estimate height using pixel ratio
-    tree_h = tree_box[3] - tree_box[1]
-    person_h = person_box[3] - person_box[1]
-    ratio = tree_h / person_h
-    estimated_height = ratio * KNOWN_PERSON_HEIGHT
-
-    return estimated_height, None
-
-# Function to get species info from Wikipedia
-def get_species_info(species_name):
-    try:
-        summary = wikipedia.summary(species_name, sentences=2)
-        return summary
-    except:
-        return "No species info found."
-
-# Main process function
-def process(image, species_guess, latitude, longitude):
-    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
-    image_array = np.array(image)
-
-    height, error = estimate_tree_height(image_array)
-    if error:
-        return error, None
-
-    species_info = get_species_info(species_guess)
-
-"
-"
-
-# Gradio
-
 import gradio as gr
 import torch
+import cv2
 import numpy as np
+import requests
+import json
+from PIL import Image
+import transformers
+from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
 import wikipedia
+import folium
+from geopy.geocoders import Nominatim
+import base64
+from io import BytesIO
+import tempfile
 import os

+class TreeAnalyzer:
+    def __init__(self):
+        self.setup_models()
+        self.geolocator = Nominatim(user_agent="tree_analyzer")
+
+    def setup_models(self):
+        """Initialize all required models"""
+        print("Loading models...")
+
+        # Load MiDaS model for depth estimation
+        try:
+            self.midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small')
+            self.midas.eval()
+            self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
+            self.transform = self.midas_transforms.small_transform
+            print("✓ MiDaS model loaded successfully")
+        except Exception as e:
+            print(f"Error loading MiDaS: {e}")
+            self.midas = None
+
+        # Load plant classification model
+        try:
+            self.plant_classifier = pipeline(
+                "image-classification",
+                model="microsoft/resnet-50",
+                top_k=3
+            )
+            print("✓ Plant classifier loaded successfully")
+        except Exception as e:
+            print(f"Error loading plant classifier: {e}")
+            # Fall back to a general-purpose image classifier if available
+            try:
+                self.plant_classifier = pipeline(
+                    "image-classification",
+                    model="google/vit-base-patch16-224",
+                    top_k=3
+                )
+                print("✓ Fallback classifier loaded successfully")
+            except:
+                self.plant_classifier = None
+                print("✗ Could not load plant classifier")
+
+    def estimate_tree_height(self, image, known_object_height=1.7):
+        """
+        Estimate tree height using MiDaS depth estimation
+        known_object_height: assumed height of reference object (person = 1.7m)
+        """
+        if self.midas is None:
+            return "MiDaS model not available", None
+
+        try:
+            # Convert PIL to OpenCV format
+            img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
+
+            # Prepare image for MiDaS
+            input_batch = self.transform(img_cv).to(torch.float32)
+
+            # Generate depth map
+            with torch.no_grad():
+                prediction = self.midas(input_batch)
+                prediction = torch.nn.functional.interpolate(
+                    prediction.unsqueeze(1),
+                    size=img_cv.shape[:2],
+                    mode="bicubic",
+                    align_corners=False,
+                ).squeeze()
+
+            # Convert to numpy
+            depth_map = prediction.cpu().numpy()
+
+            # Normalize depth map for visualization
+            depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
+            depth_map_colored = cv2.applyColorMap(depth_map_normalized, cv2.COLORMAP_PLASMA)
+
+            # Simple height estimation (this is a simplified approach)
+            # In reality, you'd need camera calibration and more sophisticated methods
+            height, width = depth_map.shape
+
+            # Assume tree is in the center-upper portion of the image
+            tree_region = depth_map[int(height*0.1):int(height*0.8), int(width*0.3):int(width*0.7)]
+
+            # Calculate relative height based on depth variations
+            depth_range = np.max(tree_region) - np.min(tree_region)
+
+            # Rough estimation: scale based on depth range and image dimensions
+            estimated_height = (depth_range / np.max(depth_map)) * height * 0.02  # Scaling factor
+            estimated_height = max(2.0, min(50.0, estimated_height))  # Clamp between 2-50 meters
+
+            return f"Estimated height: {estimated_height:.1f} meters", depth_map_colored
+
+        except Exception as e:
+            return f"Error in height estimation: {str(e)}", None
+
+    def identify_tree_species(self, image):
+        """Identify tree species using image classification"""
+        if self.plant_classifier is None:
+            return "Plant classifier not available", []
+
+        try:
+            # Get predictions
+            predictions = self.plant_classifier(image)
+
+            # Filter for tree-related predictions
+            tree_keywords = ['tree', 'oak', 'pine', 'maple', 'birch', 'cedar', 'fir', 'palm', 'willow', 'cherry', 'apple']
+            tree_predictions = []
+
+            for pred in predictions:
+                label = pred['label'].lower()
+                if any(keyword in label for keyword in tree_keywords):
+                    tree_predictions.append(pred)
+
+            if not tree_predictions:
+                tree_predictions = predictions[:2]  # Take top 2 if no tree-specific matches
+
+            # Get Wikipedia information for top prediction
+            species_info = []
+            for pred in tree_predictions[:2]:
+                try:
+                    # Search Wikipedia
+                    wiki_results = wikipedia.search(pred['label'] + " tree", results=1)
+                    if wiki_results:
+                        page = wikipedia.page(wiki_results[0])
+                        summary = wikipedia.summary(wiki_results[0], sentences=2)
+                        species_info.append({
+                            'species': pred['label'],
+                            'confidence': pred['score'],
+                            'wiki_title': page.title,
+                            'summary': summary,
+                            'url': page.url
+                        })
+                except:
+                    species_info.append({
+                        'species': pred['label'],
+                        'confidence': pred['score'],
+                        'wiki_title': 'Not found',
+                        'summary': 'No Wikipedia information available',
+                        'url': None
+                    })
+
+            return "Species identified successfully", species_info
+
+        except Exception as e:
+            return f"Error in species identification: {str(e)}", []
+
+    def get_location_info(self, latitude, longitude):
+        """Get location information from coordinates"""
+        if latitude is None or longitude is None:
+            return "Location not provided", None
+
+        try:
+            location = self.geolocator.reverse(f"{latitude}, {longitude}")
+
+            # Create a map
+            m = folium.Map(location=[latitude, longitude], zoom_start=15)
+            folium.Marker(
+                [latitude, longitude],
+                popup=f"Tree Location<br>Lat: {latitude:.6f}<br>Lon: {longitude:.6f}",
+                tooltip="Tree Location"
+            ).add_to(m)
+
+            # Render the map as an embeddable HTML snippet (an iframe) so gr.HTML can display it
+            map_html = m._repr_html_()
+
+            address = location.address if location else "Address not found"
+
+            return f"Location: {address}", map_html
+
+        except Exception as e:
+            return f"Error getting location info: {str(e)}", None

+def analyze_tree(image, latitude, longitude):
+    """Main function to analyze tree from image and location"""
+    if image is None:
+        return "Please upload an image", "", "", "", None
+
+    analyzer = TreeAnalyzer()
+
+    # Analyze height
+    height_result, depth_map = analyzer.estimate_tree_height(image)
+
+    # Identify species
+    species_result, species_info = analyzer.identify_tree_species(image)
+
+    # Get location info
+    location_result, map_html = analyzer.get_location_info(latitude, longitude)
+
+    # Format species information
+    species_text = ""
+    if species_info:
+        for info in species_info:
+            species_text += f"**{info['species']}** (Confidence: {info['confidence']:.2f})\n"
+            species_text += f"*{info['summary']}*\n"
+            if info['url']:
+                species_text += f"[Wikipedia Link]({info['url']})\n"
+            species_text += "\n"
+
+    # Return results
+    return (
+        height_result,
+        species_result,
+        species_text,
+        location_result,
+        map_html
+    )

+# Create Gradio interface
+def create_interface():
+    with gr.Blocks(title="Tree Analysis App", theme=gr.themes.Soft()) as demo:
+        gr.Markdown("""
+        # 🌳 Tree Analysis App
+
+        Upload an image of a tree and optionally provide GPS coordinates to get:
+        - **Tree height estimation** using MiDaS depth estimation
+        - **Species identification** with Wikipedia information
+        - **Location mapping** of where the tree was captured
+        """)
+
+        with gr.Row():
+            with gr.Column(scale=1):
+                image_input = gr.Image(type="pil", label="Upload Tree Image")
+
+                with gr.Row():
+                    latitude_input = gr.Number(
+                        label="Latitude (optional)",
+                        placeholder="e.g., 40.7128",
+                        info="GPS latitude coordinate"
+                    )
+                    longitude_input = gr.Number(
+                        label="Longitude (optional)",
+                        placeholder="e.g., -74.0060",
+                        info="GPS longitude coordinate"
+                    )
+
+                analyze_btn = gr.Button("🔍 Analyze Tree", variant="primary")
+
+            with gr.Column(scale=2):
+                with gr.Tab("Results"):
+                    height_output = gr.Textbox(label="Height Estimation", lines=2)
+                    species_status = gr.Textbox(label="Species Identification Status", lines=1)
+                    species_output = gr.Markdown(label="Species Information")
+                    location_output = gr.Textbox(label="Location Information", lines=2)
+
+                with gr.Tab("Location Map"):
+                    map_output = gr.HTML(label="Location Map")
+
+        # Connect the analyze button
+        analyze_btn.click(
+            fn=analyze_tree,
+            inputs=[image_input, latitude_input, longitude_input],
+            outputs=[height_output, species_status, species_output, location_output, map_output]
+        )
+
+        # Example section
+        gr.Markdown("""
+        ## 📱 Usage Tips:
+        1. **Take a clear photo** of the tree with good lighting
+        2. **Include reference objects** (like people) for better height estimation
+        3. **Enable GPS** on your phone and note the coordinates
+        4. **Upload the image** and enter GPS coordinates if available
+        5. **Click Analyze** to get comprehensive tree information
+
+        ## 🔗 Sharing:
+        This app generates a shareable link that you can send to others!
+        """)
+
+    return demo

+# Main execution
+if __name__ == "__main__":
+    # Create and launch the interface
+    demo = create_interface()
+
+    # Launch with sharing enabled to generate a public link
+    demo.launch(
+        share=True,  # This creates a shareable public link
+        server_name="0.0.0.0",
+        server_port=7860,
+        show_error=True
+    )
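
A note on the height estimate: the new code itself flags that MiDaS returns only relative depth and that the 0.02 scaling factor is a rough guess, while the removed version instead scaled the tree's pixel height against a detected person of known height. A minimal standalone sketch of that reference-object idea, assuming (x1, y1, x2, y2) pixel boxes are already available from some detector and that the person stands at roughly the same distance from the camera as the tree (the function name and example numbers are illustrative, not from the repository):

def height_from_reference(person_box, tree_box, person_height_m=1.7):
    """Scale the tree's pixel height by a reference person of known height.

    Both boxes are (x1, y1, x2, y2) in pixels; the estimate assumes the
    person and the tree are at a similar distance from the camera.
    """
    person_px = person_box[3] - person_box[1]
    tree_px = tree_box[3] - tree_box[1]
    if person_px <= 0:
        raise ValueError("invalid person box")
    return (tree_px / person_px) * person_height_m

# Example: a 180 px tall person next to a 950 px tall tree -> about 9.0 m
print(round(height_from_reference((400, 620, 470, 800), (100, 50, 600, 1000)), 1))

This is essentially what the removed estimate_tree_height() did with detector boxes, and it could be combined with the MiDaS depth map to decide which pixels belong to the tree.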
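
On dependencies: the new imports pull in several third-party packages, and the intel-isl/MiDaS torch.hub models also depend on timm at load time. If the Space's requirements file is not updated to match, the app will fail at startup. A requirements.txt along these lines (package names inferred from the imports above, not taken from the repository) should cover them:

gradio
torch
timm                     # needed by the intel-isl/MiDaS torch.hub models
opencv-python-headless   # provides cv2 without GUI/system GL dependencies
numpy
Pillow
transformers
requests
wikipedia
folium
geopy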