pranaya20 commited on
Commit
a261c62
·
verified ·
1 Parent(s): 0e4ab7b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -120
app.py CHANGED
@@ -1,128 +1,102 @@
 
1
  import gradio as gr
2
- from ultralytics import YOLO
3
- import torch
4
  import cv2
 
5
  import numpy as np
 
 
6
  import pandas as pd
7
- import h3
8
- import folium
9
- from folium.plugins import MarkerCluster
10
- from PIL import Image
11
- import base64
12
  import os
 
 
13
  from datetime import datetime
14
 
15
- # Load YOLO model
16
- model = YOLO("yolov8n.pt")
17
-
18
- # Try loading MiDaS depth estimation model
19
- try:
20
- midas = torch.hub.load("intel-isl/MiDaS", "MiDaS_small", trust_repo=True)
21
- midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms", trust_repo=True)
22
- transform = midas_transforms.small_transform
23
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
- midas.to(device)
25
- height_estimation_available = True
26
- except:
27
- height_estimation_available = False
28
-
29
- # Placeholder species classifier (could be replaced by real model)
30
- def classify_species(image):
31
- return "Neem Tree" # Dummy result
32
-
33
- # Create output CSV if it doesn't exist
34
- if not os.path.exists("tree_data.csv"):
35
- pd.DataFrame(columns=["timestamp", "latitude", "longitude", "h3_index", "species", "height_m", "image_path"]).to_csv("tree_data.csv", index=False)
36
-
37
- # Tree analysis function
38
- def analyze_tree(image, latitude, longitude):
39
- image_np = np.array(image)
40
-
41
- # Detect tree using YOLO
42
- results = model(image_np)[0]
43
- boxes = results.boxes.xyxy.cpu().numpy().astype(int)
44
- classes = results.boxes.cls.cpu().numpy().astype(int)
45
-
46
- height_m = "N/A"
47
-
48
- # Depth estimation (if available)
49
- if height_estimation_available:
50
- input_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
51
- input_image = cv2.resize(input_image, (256, 256))
52
- input_tensor = transform(input_image).to(device)
53
- with torch.no_grad():
54
- prediction = midas(input_tensor.unsqueeze(0))
55
- prediction = torch.nn.functional.interpolate(
56
- prediction.unsqueeze(1),
57
- size=image_np.shape[:2],
58
- mode="bicubic",
59
- align_corners=False,
60
- ).squeeze()
61
- depth_map = prediction.cpu().numpy()
62
- height_m = round(float(np.median(depth_map)), 2)
63
-
64
- # Species classification (placeholder)
65
- species = classify_species(image)
66
-
67
- # H3 geolocation
68
- h3_index = h3.geo_to_h3(float(latitude), float(longitude), 9)
69
-
70
- # Save image
71
- timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
72
- image_path = f"tree_{timestamp}.jpg"
73
- image.save(image_path)
74
-
75
- # Append to CSV
76
- df = pd.read_csv("tree_data.csv")
77
- df.loc[len(df)] = [timestamp, latitude, longitude, h3_index, species, height_m, image_path]
78
- df.to_csv("tree_data.csv", index=False)
79
-
80
- return f"Species: {species}\nEstimated Height: {height_m} meters\nH3 Index: {h3_index}"
81
-
82
- # Generate map with markers
83
- def generate_map():
84
- df = pd.read_csv("tree_data.csv")
85
- m = folium.Map(location=[20.5937, 78.9629], zoom_start=5)
86
- marker_cluster = MarkerCluster().add_to(m)
87
-
88
- for _, row in df.iterrows():
89
- folium.Marker(
90
- location=[row["latitude"], row["longitude"]],
91
- popup=f"Species: {row['species']}\nHeight: {row['height_m']}\nH3: {row['h3_index']}",
92
- ).add_to(marker_cluster)
93
-
94
- map_path = "tree_map.html"
95
- m.save(map_path)
96
- with open(map_path, "r") as f:
97
- html_content = f.read()
98
- return html_content
99
-
100
- # Gradio interface
101
- def js_location():
102
- return """
103
- async () => {
104
- const pos = await new Promise((resolve, reject) => {
105
- navigator.geolocation.getCurrentPosition(resolve, reject);
106
- });
107
- return [pos.coords.latitude, pos.coords.longitude];
108
  }
109
- """
110
-
111
- with gr.Blocks() as demo:
112
- with gr.Row():
113
- with gr.Column():
114
- cam = gr.Image(source="webcam", streaming=True)
115
- loc_btn = gr.Button("Get Location")
116
- lat = gr.Textbox(label="Latitude")
117
- lon = gr.Textbox(label="Longitude")
118
- analyze_btn = gr.Button("Analyze Tree")
119
- output = gr.Textbox(label="Results")
120
- with gr.Column():
121
- map_output = gr.HTML()
122
-
123
- loc_btn.click(None, js=js_location(), outputs=[lat, lon])
124
- analyze_btn.click(fn=analyze_tree, inputs=[cam, lat, lon], outputs=output)
125
- analyze_btn.click(fn=generate_map, outputs=map_output)
126
-
127
- if __name__ == "__main__":
128
- demo.launch()
 
 
 
 
1
+ # arboreal_gradio_app/app.py
2
  import gradio as gr
 
 
3
  import cv2
4
+ import torch
5
  import numpy as np
6
+ import time
7
+ import wikipedia
8
  import pandas as pd
 
 
 
 
 
9
  import os
10
+ from ultralytics import YOLO
11
+ from PIL import Image
12
  from datetime import datetime
13
 
14
+ # Load YOLOv8 model
15
+ model = YOLO("yolov8n.pt") # You can replace with a custom-trained tree model
16
+
17
+ # Constants
18
+ KNOWN_PERSON_HEIGHT = 1.7 # meters
19
+ CSV_FILE = "tree_log.csv"
20
+
21
+ # Initialize CSV if it doesn't exist
22
+ if not os.path.exists(CSV_FILE):
23
+ pd.DataFrame(columns=["Image", "Height (m)", "Species", "Latitude", "Longitude", "Timestamp"]).to_csv(CSV_FILE, index=False)
24
+
25
+ # Function to estimate height using YOLO detections
26
+ def estimate_tree_height(img):
27
+ results = model(img)
28
+ person_box, tree_box = None, None
29
+
30
+ for result in results:
31
+ for box in result.boxes:
32
+ cls = result.names[int(box.cls[0])]
33
+ xyxy = box.xyxy[0].cpu().numpy()
34
+ if cls.lower() == "person" and person_box is None:
35
+ person_box = xyxy
36
+ elif cls.lower() == "tree" and tree_box is None:
37
+ tree_box = xyxy
38
+
39
+ if tree_box is None or person_box is None:
40
+ return None, "Could not detect both tree and person."
41
+
42
+ # Estimate height using pixel ratio
43
+ tree_h = tree_box[3] - tree_box[1]
44
+ person_h = person_box[3] - person_box[1]
45
+ ratio = tree_h / person_h
46
+ estimated_height = ratio * KNOWN_PERSON_HEIGHT
47
+
48
+ return estimated_height, None
49
+
50
+ # Function to get species info from Wikipedia
51
+ def get_species_info(species_name):
52
+ try:
53
+ summary = wikipedia.summary(species_name, sentences=2)
54
+ return summary
55
+ except:
56
+ return "No species info found."
57
+
58
+ # Main process function
59
+ def process(image, species_guess, latitude, longitude):
60
+ timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
61
+ image_array = np.array(image)
62
+
63
+ height, error = estimate_tree_height(image_array)
64
+ if error:
65
+ return error, None
66
+
67
+ species_info = get_species_info(species_guess)
68
+
69
+ # Save to CSV
70
+ df = pd.read_csv(CSV_FILE)
71
+ image_filename = f"image_{int(time.time())}.jpg"
72
+ new_row = {
73
+ "Image": image_filename,
74
+ "Height (m)": round(height, 2),
75
+ "Species": species_guess,
76
+ "Latitude": latitude,
77
+ "Longitude": longitude,
78
+ "Timestamp": timestamp
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  }
80
+ df = pd.concat([df, pd.DataFrame([new_row])], ignore_index=True)
81
+ df.to_csv(CSV_FILE, index=False)
82
+
83
+ return f"Estimated Height: {round(height, 2)} meters\nSpecies: {species_guess}", species_info
84
+
85
+ # Gradio UI
86
+ iface = gr.Interface(
87
+ fn=process,
88
+ inputs=[
89
+ gr.Image(label="Capture Tree + Person Image"),
90
+ gr.Textbox(label="Species Guess (optional)", placeholder="E.g. Neem Tree"),
91
+ gr.Textbox(label="Latitude (optional)", placeholder="Enter from GPS or manually"),
92
+ gr.Textbox(label="Longitude (optional)", placeholder="Enter from GPS or manually")
93
+ ],
94
+ outputs=[
95
+ gr.Text(label="Result"),
96
+ gr.Text(label="Species Information (Wikipedia)")
97
+ ],
98
+ title="🌲 Tree Analyzer - Height + Species + Location",
99
+ description="Upload a photo with a tree and a person. Optionally enter location. Returns estimated height and species info. Logs all data to CSV."
100
+ )
101
+
102
+ iface.launch()