pranaya20 commited on
Commit
eb93675
Β·
verified Β·
1 Parent(s): aa4f4c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +199 -220
app.py CHANGED
@@ -2,296 +2,275 @@ import gradio as gr
2
  import torch
3
  import cv2
4
  import numpy as np
5
- import requests
6
- import json
7
  from PIL import Image
8
- import transformers
9
  from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
10
  import wikipedia
11
  import folium
12
- from geopy.geocoders import Nominatim
13
- import base64
14
- from io import BytesIO
15
  import tempfile
16
  import os
 
 
 
 
 
 
 
17
 
18
  class TreeAnalyzer:
19
  def __init__(self):
20
  self.setup_models()
21
- self.geolocator = Nominatim(user_agent="tree_analyzer")
22
 
23
  def setup_models(self):
24
- """Initialize all required models"""
25
- print("Loading models...")
26
 
27
- # Load MiDaS model for depth estimation
 
28
  try:
29
- self.midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small')
30
  self.midas.eval()
31
- self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms')
32
  self.transform = self.midas_transforms.small_transform
33
- print("βœ“ MiDaS model loaded successfully")
34
  except Exception as e:
35
- print(f"Error loading MiDaS: {e}")
36
- self.midas = None
37
 
38
  # Load plant classification model
39
- try:
40
- self.plant_classifier = pipeline(
41
- "image-classification",
42
- model="microsoft/resnet-50",
43
- return_top_k=3
44
- )
45
- print("βœ“ Plant classifier loaded successfully")
46
- except Exception as e:
47
- print(f"Error loading plant classifier: {e}")
48
- # Fallback to a more specific plant model if available
49
  try:
50
  self.plant_classifier = pipeline(
51
  "image-classification",
52
- model="google/vit-base-patch16-224",
53
- return_top_k=3
54
  )
55
- print("βœ“ Fallback classifier loaded successfully")
56
- except:
57
- self.plant_classifier = None
58
- print("βœ— Could not load plant classifier")
 
59
 
60
- def estimate_tree_height(self, image, known_object_height=1.7):
61
- """
62
- Estimate tree height using MiDaS depth estimation
63
- known_object_height: assumed height of reference object (person = 1.7m)
64
- """
65
  if self.midas is None:
66
- return "MiDaS model not available", None
67
 
68
  try:
69
- # Convert PIL to OpenCV format
70
  img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
71
 
72
- # Prepare image for MiDaS
73
- input_batch = self.transform(img_cv).to(torch.float32)
74
 
75
- # Generate depth map
76
  with torch.no_grad():
77
  prediction = self.midas(input_batch)
78
  prediction = torch.nn.functional.interpolate(
79
  prediction.unsqueeze(1),
80
- size=img_cv.shape[:2],
81
  mode="bicubic",
82
  align_corners=False,
83
  ).squeeze()
84
 
85
- # Convert to numpy
86
  depth_map = prediction.cpu().numpy()
87
 
88
- # Normalize depth map for visualization
89
- depth_map_normalized = cv2.normalize(depth_map, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
90
- depth_map_colored = cv2.applyColorMap(depth_map_normalized, cv2.COLORMAP_PLASMA)
91
-
92
- # Simple height estimation (this is a simplified approach)
93
- # In reality, you'd need camera calibration and more sophisticated methods
94
- height, width = depth_map.shape
95
-
96
- # Assume tree is in the center-upper portion of the image
97
- tree_region = depth_map[int(height*0.1):int(height*0.8), int(width*0.3):int(width*0.7)]
98
-
99
- # Calculate relative height based on depth variations
100
- depth_range = np.max(tree_region) - np.min(tree_region)
101
-
102
- # Rough estimation: scale based on depth range and image dimensions
103
- estimated_height = (depth_range / np.max(depth_map)) * height * 0.02 # Scaling factor
104
- estimated_height = max(2.0, min(50.0, estimated_height)) # Clamp between 2-50 meters
105
-
106
- return f"Estimated height: {estimated_height:.1f} meters", depth_map_colored
107
 
 
 
 
 
 
 
 
 
 
108
  except Exception as e:
109
- return f"Error in height estimation: {str(e)}", None
 
110
 
111
  def identify_tree_species(self, image):
112
- """Identify tree species using image classification"""
113
  if self.plant_classifier is None:
114
- return "Plant classifier not available", []
115
 
116
  try:
 
 
 
 
117
  # Get predictions
118
  predictions = self.plant_classifier(image)
119
 
120
- # Filter for tree-related predictions
121
- tree_keywords = ['tree', 'oak', 'pine', 'maple', 'birch', 'cedar', 'fir', 'palm', 'willow', 'cherry', 'apple']
122
- tree_predictions = []
 
 
 
 
 
 
 
 
 
 
 
 
123
 
 
 
124
  for pred in predictions:
125
  label = pred['label'].lower()
126
- if any(keyword in label for keyword in tree_keywords):
127
- tree_predictions.append(pred)
128
-
129
- if not tree_predictions:
130
- tree_predictions = predictions[:2] # Take top 2 if no tree-specific matches
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- # Get Wikipedia information for top prediction
133
- species_info = []
134
- for pred in tree_predictions[:2]:
135
- try:
136
- # Search Wikipedia
137
- wiki_results = wikipedia.search(pred['label'] + " tree", results=1)
138
- if wiki_results:
139
- page = wikipedia.page(wiki_results[0])
140
- summary = wikipedia.summary(wiki_results[0], sentences=2)
141
- species_info.append({
142
- 'species': pred['label'],
143
- 'confidence': pred['score'],
144
- 'wiki_title': page.title,
145
- 'summary': summary,
146
- 'url': page.url
147
- })
148
- except:
149
- species_info.append({
150
- 'species': pred['label'],
151
- 'confidence': pred['score'],
152
- 'wiki_title': 'Not found',
153
- 'summary': 'No Wikipedia information available',
154
- 'url': None
155
- })
156
 
157
- return "Species identified successfully", species_info
 
158
 
 
 
 
 
 
159
  except Exception as e:
160
- return f"Error in species identification: {str(e)}", []
 
161
 
162
- def get_location_info(self, latitude, longitude):
163
- """Get location information from coordinates"""
164
- if latitude is None or longitude is None:
165
- return "Location not provided", None
166
-
167
  try:
168
- location = self.geolocator.reverse(f"{latitude}, {longitude}")
 
169
 
170
- # Create a map
171
- m = folium.Map(location=[latitude, longitude], zoom_start=15)
172
- folium.Marker(
173
- [latitude, longitude],
174
- popup=f"Tree Location<br>Lat: {latitude:.6f}<br>Lon: {longitude:.6f}",
175
- tooltip="Tree Location"
176
- ).add_to(m)
177
 
178
- # Save map to temporary file
179
- map_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html')
180
- m.save(map_file.name)
181
-
182
- address = location.address if location else "Address not found"
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
- return f"Location: {address}", map_file.name
 
 
 
 
185
 
186
  except Exception as e:
187
- return f"Error getting location info: {str(e)}", None
 
 
 
 
188
 
189
  def analyze_tree(image, latitude, longitude):
190
- """Main function to analyze tree from image and location"""
191
  if image is None:
192
- return "Please upload an image", "", "", "", None
193
-
194
- analyzer = TreeAnalyzer()
195
-
196
- # Analyze height
197
- height_result, depth_map = analyzer.estimate_tree_height(image)
198
-
199
- # Identify species
200
- species_result, species_info = analyzer.identify_tree_species(image)
201
 
202
- # Get location info
203
- location_result, map_file = analyzer.get_location_info(latitude, longitude)
204
-
205
- # Format species information
206
- species_text = ""
207
- if species_info:
208
- for info in species_info:
209
- species_text += f"**{info['species']}** (Confidence: {info['confidence']:.2f})\n"
210
- species_text += f"*{info['summary']}*\n"
211
- if info['url']:
212
- species_text += f"[Wikipedia Link]({info['url']})\n"
213
- species_text += "\n"
214
-
215
- # Return results
216
- return (
217
- height_result,
218
- species_result,
219
- species_text,
220
- location_result,
221
- map_file
222
- )
223
-
224
- # Create Gradio interface
225
- def create_interface():
226
- with gr.Blocks(title="Tree Analysis App", theme=gr.themes.Soft()) as demo:
227
- gr.Markdown("""
228
- # 🌳 Tree Analysis App
229
 
230
- Upload an image of a tree and optionally provide GPS coordinates to get:
231
- - **Tree height estimation** using MiDaS depth estimation
232
- - **Species identification** with Wikipedia information
233
- - **Location mapping** of where the tree was captured
234
- """)
235
 
236
- with gr.Row():
237
- with gr.Column(scale=1):
238
- image_input = gr.Image(type="pil", label="Upload Tree Image")
239
-
240
- with gr.Row():
241
- latitude_input = gr.Number(
242
- label="Latitude (optional)",
243
- placeholder="e.g., 40.7128",
244
- info="GPS latitude coordinate"
245
- )
246
- longitude_input = gr.Number(
247
- label="Longitude (optional)",
248
- placeholder="e.g., -74.0060",
249
- info="GPS longitude coordinate"
250
- )
251
-
252
- analyze_btn = gr.Button("πŸ” Analyze Tree", variant="primary")
253
-
254
- with gr.Column(scale=2):
255
- with gr.Tab("Results"):
256
- height_output = gr.Textbox(label="Height Estimation", lines=2)
257
- species_status = gr.Textbox(label="Species Identification Status", lines=1)
258
- species_output = gr.Markdown(label="Species Information")
259
- location_output = gr.Textbox(label="Location Information", lines=2)
260
-
261
- with gr.Tab("Location Map"):
262
- map_output = gr.HTML(label="Location Map")
263
 
264
- # Connect the analyze button
265
- analyze_btn.click(
266
- fn=analyze_tree,
267
- inputs=[image_input, latitude_input, longitude_input],
268
- outputs=[height_output, species_status, species_output, location_output, map_output]
269
- )
 
 
 
 
 
 
 
 
 
 
270
 
271
- # Example section
272
- gr.Markdown("""
273
- ## πŸ“± Usage Tips:
274
- 1. **Take a clear photo** of the tree with good lighting
275
- 2. **Include reference objects** (like people) for better height estimation
276
- 3. **Enable GPS** on your phone and note the coordinates
277
- 4. **Upload the image** and enter GPS coordinates if available
278
- 5. **Click Analyze** to get comprehensive tree information
279
 
280
- ## πŸ”— Sharing:
281
- This app generates a shareable link that you can send to others!
282
- """)
283
-
284
- return demo
285
-
286
- # Main execution
287
- if __name__ == "__main__":
288
- # Create and launch the interface
289
- demo = create_interface()
290
-
291
- # Launch with sharing enabled to generate a public link
292
- demo.launch(
293
- share=True, # This creates a shareable public link
294
- server_name="0.0.0.0",
295
- server_port=7860,
296
- show_error=True
297
- )
 
 
 
 
 
 
 
 
 
 
2
  import torch
3
  import cv2
4
  import numpy as np
 
 
5
  from PIL import Image
 
6
  from transformers import pipeline, AutoImageProcessor, AutoModelForImageClassification
7
  import wikipedia
8
  import folium
 
 
 
9
  import tempfile
10
  import os
11
+ import logging
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.INFO)
17
+ logger = logging.getLogger(__name__)
18
 
19
  class TreeAnalyzer:
20
  def __init__(self):
21
  self.setup_models()
 
22
 
23
  def setup_models(self):
24
+ """Initialize models optimized for HF Spaces"""
25
+ logger.info("Loading models for HF Spaces...")
26
 
27
+ # Load depth estimation model
28
+ self.midas = None
29
  try:
30
+ self.midas = torch.hub.load('intel-isl/MiDaS', 'MiDaS_small', trust_repo=True)
31
  self.midas.eval()
32
+ self.midas_transforms = torch.hub.load('intel-isl/MiDaS', 'transforms', trust_repo=True)
33
  self.transform = self.midas_transforms.small_transform
34
+ logger.info("βœ“ MiDaS loaded")
35
  except Exception as e:
36
+ logger.error(f"MiDaS failed: {e}")
 
37
 
38
  # Load plant classification model
39
+ self.plant_classifier = None
40
+ models_to_try = [
41
+ "google/vit-base-patch16-224",
42
+ "microsoft/resnet-50",
43
+ "facebook/convnext-tiny-224"
44
+ ]
45
+
46
+ for model_name in models_to_try:
 
 
47
  try:
48
  self.plant_classifier = pipeline(
49
  "image-classification",
50
+ model=model_name,
51
+ return_top_k=10
52
  )
53
+ logger.info(f"βœ“ Loaded classifier: {model_name}")
54
+ break
55
+ except Exception as e:
56
+ logger.warning(f"Failed to load {model_name}: {e}")
57
+ continue
58
 
59
+ def estimate_tree_height(self, image):
60
+ """Estimate tree height using depth estimation"""
 
 
 
61
  if self.midas is None:
62
+ return "Height estimation not available (MiDaS model failed to load)"
63
 
64
  try:
65
+ # Convert and resize image
66
  img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
67
+ h, w = img_cv.shape[:2]
68
+
69
+ # Resize for memory efficiency
70
+ if h > 384 or w > 384:
71
+ scale = min(384/h, 384/w)
72
+ new_h, new_w = int(h*scale), int(w*scale)
73
+ img_cv = cv2.resize(img_cv, (new_w, new_h))
74
 
75
+ # Process with MiDaS
76
+ input_batch = self.transform(img_cv)
77
 
 
78
  with torch.no_grad():
79
  prediction = self.midas(input_batch)
80
  prediction = torch.nn.functional.interpolate(
81
  prediction.unsqueeze(1),
82
+ size=(img_cv.shape[0], img_cv.shape[1]),
83
  mode="bicubic",
84
  align_corners=False,
85
  ).squeeze()
86
 
 
87
  depth_map = prediction.cpu().numpy()
88
 
89
+ # Simple height estimation
90
+ h_img, w_img = depth_map.shape
91
+ center_region = depth_map[h_img//4:3*h_img//4, w_img//4:3*w_img//4]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ if center_region.size > 0:
94
+ depth_range = np.max(center_region) - np.min(center_region)
95
+ height_ratio = center_region.shape[0] / h_img
96
+ estimated_height = max(1.5, min(50.0, (depth_range * height_ratio * 30)))
97
+
98
+ return f"Estimated height: {estimated_height:.1f} meters\n(Approximate estimate based on image depth analysis)"
99
+ else:
100
+ return "Could not estimate height from this image"
101
+
102
  except Exception as e:
103
+ logger.error(f"Height estimation error: {e}")
104
+ return f"Height estimation failed: {str(e)}"
105
 
106
  def identify_tree_species(self, image):
107
+ """Identify tree species with better filtering"""
108
  if self.plant_classifier is None:
109
+ return "Species identification not available (classifier failed to load)", []
110
 
111
  try:
112
+ # Resize image for processing
113
+ if image.size[0] > 224 or image.size[1] > 224:
114
+ image = image.resize((224, 224), Image.Resampling.LANCZOS)
115
+
116
  # Get predictions
117
  predictions = self.plant_classifier(image)
118
 
119
+ # Enhanced plant/tree keywords
120
+ plant_keywords = [
121
+ # Trees
122
+ 'tree', 'oak', 'pine', 'maple', 'birch', 'cedar', 'fir', 'palm', 'willow',
123
+ 'cherry', 'apple', 'spruce', 'poplar', 'ash', 'elm', 'beech', 'sycamore',
124
+ 'acacia', 'eucalyptus', 'magnolia', 'chestnut', 'walnut', 'hickory',
125
+ 'cypress', 'juniper', 'redwood', 'bamboo', 'mahogany', 'teak',
126
+ # Plants and botanical terms
127
+ 'plant', 'leaf', 'leaves', 'branch', 'bark', 'forest', 'wood', 'botanical',
128
+ 'flora', 'foliage', 'evergreen', 'deciduous', 'conifer', 'hardwood',
129
+ 'softwood', 'timber', 'shrub', 'bush', 'vine', 'fern', 'moss',
130
+ # Specific species indicators
131
+ 'quercus', 'pinus', 'acer', 'betula', 'fagus', 'tilia', 'fraxinus',
132
+ 'platanus', 'castanea', 'juglans', 'carya', 'ulmus', 'salix'
133
+ ]
134
 
135
+ # Process and score predictions
136
+ species_candidates = []
137
  for pred in predictions:
138
  label = pred['label'].lower()
139
+ confidence = pred['score']
140
+
141
+ # Calculate plant relevance score
142
+ plant_score = sum(1 for keyword in plant_keywords if keyword in label)
143
+ is_plant_related = plant_score > 0
144
+
145
+ # Get Wikipedia info
146
+ wiki_info = self.get_wikipedia_info(pred['label'])
147
+
148
+ species_candidates.append({
149
+ 'species': pred['label'],
150
+ 'confidence': confidence,
151
+ 'plant_score': plant_score,
152
+ 'is_plant_related': is_plant_related,
153
+ 'wiki_info': wiki_info
154
+ })
155
 
156
+ # Sort by plant relevance and confidence
157
+ species_candidates.sort(key=lambda x: (x['plant_score'], x['confidence']), reverse=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
+ # Return top candidates
160
+ final_results = species_candidates[:3]
161
 
162
+ if any(result['is_plant_related'] for result in final_results):
163
+ return "Species identification completed", final_results
164
+ else:
165
+ return "Possible species identified (may not be plants)", final_results
166
+
167
  except Exception as e:
168
+ logger.error(f"Species identification error: {e}")
169
+ return f"Species identification failed: {str(e)}", []
170
 
171
+ def get_wikipedia_info(self, species_name):
172
+ """Get Wikipedia information with better error handling"""
 
 
 
173
  try:
174
+ # Clean species name
175
+ clean_name = species_name.split(',')[0].split('(')[0].strip()
176
 
177
+ search_queries = [
178
+ clean_name,
179
+ f"{clean_name} tree",
180
+ f"{clean_name} plant",
181
+ f"{clean_name} species"
182
+ ]
 
183
 
184
+ for query in search_queries:
185
+ try:
186
+ results = wikipedia.search(query, results=2)
187
+ if results:
188
+ for result in results:
189
+ try:
190
+ page = wikipedia.page(result, auto_suggest=False)
191
+ summary = wikipedia.summary(result, sentences=2, auto_suggest=False)
192
+ return {
193
+ 'title': page.title,
194
+ 'summary': summary,
195
+ 'url': page.url
196
+ }
197
+ except:
198
+ continue
199
+ except:
200
+ continue
201
 
202
+ return {
203
+ 'title': 'No information found',
204
+ 'summary': f'Wikipedia information not available for {species_name}',
205
+ 'url': None
206
+ }
207
 
208
  except Exception as e:
209
+ return {
210
+ 'title': 'Error',
211
+ 'summary': f'Could not retrieve information: {str(e)}',
212
+ 'url': None
213
+ }
214
 
215
  def analyze_tree(image, latitude, longitude):
216
+ """Main analysis function"""
217
  if image is None:
218
+ return "Please upload an image", "", "", "", ""
 
 
 
 
 
 
 
 
219
 
220
+ try:
221
+ analyzer = TreeAnalyzer()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
+ # Height estimation
224
+ height_result = analyzer.estimate_tree_height(image)
 
 
 
225
 
226
+ # Species identification
227
+ species_status, species_info = analyzer.identify_tree_species(image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
228
 
229
+ # Format species results
230
+ species_text = ""
231
+ if species_info:
232
+ for i, info in enumerate(species_info, 1):
233
+ species_text += f"## {i}. {info['species']}\n"
234
+ species_text += f"**Confidence:** {info['confidence']:.3f}\n"
235
+ species_text += f"**Plant-related:** {'Yes' if info['is_plant_related'] else 'Uncertain'}\n"
236
+
237
+ wiki = info['wiki_info']
238
+ species_text += f"**Wikipedia:** {wiki['title']}\n"
239
+ species_text += f"{wiki['summary']}\n"
240
+ if wiki['url']:
241
+ species_text += f"πŸ”— [Read more]({wiki['url']})\n"
242
+ species_text += "\n---\n"
243
+ else:
244
+ species_text = "No species information could be determined from this image."
245
 
246
+ # Location info
247
+ location_result = ""
248
+ map_html = ""
 
 
 
 
 
249
 
250
+ if latitude is not None and longitude is not None:
251
+ try:
252
+ location_result = f"Coordinates: {latitude:.6f}, {longitude:.6f}"
253
+
254
+ # Create map
255
+ m = folium.Map(location=[latitude, longitude], zoom_start=15)
256
+ folium.Marker(
257
+ [latitude, longitude],
258
+ popup=f"Tree Location<br>{latitude:.6f}, {longitude:.6f}",
259
+ tooltip="Tree Location"
260
+ ).add_to(m)
261
+
262
+ # Save map
263
+ map_file = tempfile.NamedTemporaryFile(delete=False, suffix='.html', mode='w')
264
+ m.save(map_file.name)
265
+ map_file.close()
266
+
267
+ with open(map_file.name, 'r', encoding='utf-8') as f:
268
+ map_html = f.read()
269
+ os.unlink(map_file.name)
270
+
271
+ except Exception as e:
272
+ location_result = f"Error processing location: {str(e)}"
273
+ map_html = "<p>Could not generate map</p>"
274
+ else:
275
+ location_result = "No GPS coordinates provided"
276
+ map_html = "<p>