nagasurendra committed · Commit 6070e3c · verified · 1 Parent(s): 2e47361

Update app.py

Files changed (1):
  1. app.py +246 -79
app.py CHANGED
@@ -3,109 +3,276 @@ import torch
  import gradio as gr
  import numpy as np
  import os
  import matplotlib.pyplot as plt
- from ultralytics import YOLO, __version__ as ultralytics_version
- import uuid

  # Debug: Check environment
  print(f"Torch version: {torch.__version__}")
  print(f"Gradio version: {gr.__version__}")
- print(f"Ultralytics version: {ultralytics_version}")
  print(f"CUDA available: {torch.cuda.is_available()}")

- # Load YOLOv8 model
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")
  model = YOLO('./data/best.pt').to(device)

- def process_video(video, output_folder="detected_frames", plot_graphs=False):
      if video is None:
-         yield "Error: No video uploaded", []
-         return
-
-     # Create output folder if it doesn't exist
-     if not os.path.exists(output_folder):
-         os.makedirs(output_folder)
-
      cap = cv2.VideoCapture(video)
      if not cap.isOpened():
-         yield "Error: Could not open video file", []
-         return
-
-     frame_width, frame_height = 320, 240  # Smaller resolution
-     frame_count = 0
-     frame_skip = 5  # Process every 5th frame
-     max_frames = 100  # Limit for testing
-     confidence_scores = []  # Store confidence scores for plotting
-     detected_frame_paths = []  # Store paths of frames with detections
-
      while True:
          ret, frame = cap.read()
-         if not ret or frame_count > max_frames:
              break
-
          frame_count += 1
          if frame_count % frame_skip != 0:
              continue
-
-         frame = cv2.resize(frame, (frame_width, frame_height))
-         print(f"Processing frame {frame_count}")
-
-         # Run YOLOv8 inference
-         results = model(frame)
-
-         # Save and yield frame if objects are detected
-         if results[0].boxes is not None and len(results[0].boxes) > 0:
-             annotated_frame = results[0].plot()
-             frame_filename = os.path.join(output_folder, f"frame_{frame_count:04d}.jpg")
-             cv2.imwrite(frame_filename, annotated_frame)
-             detected_frame_paths.append(frame_filename)
-
-             # Collect confidence scores for plotting
-             confs = results[0].boxes.conf.cpu().numpy()
-             confidence_scores.extend(confs)
-
-             # Yield current status and gallery
-             yield f"Processed frame {frame_count} with detections", detected_frame_paths[:]
-
      cap.release()
-
-     # Generate confidence score plot if requested
-     if plot_graphs and confidence_scores:
-         plt.figure(figsize=(10, 5))
-         plt.hist(confidence_scores, bins=20, color='blue', alpha=0.7)
-         plt.title('Distribution of Confidence Scores')
-         plt.xlabel('Confidence Score')
-         plt.ylabel('Frequency')
-         graph_path = os.path.join(output_folder, "confidence_histogram.png")
-         plt.savefig(graph_path)
-         plt.close()
-         detected_frame_paths.append(graph_path)
-
-     # Final yield with all results
-     status = f"Saved {len(detected_frame_paths)} frames with detections in {output_folder}. {f'Graph saved as {graph_path}' if plot_graphs and confidence_scores else ''}"
-     yield status, detected_frame_paths

  # Gradio interface
- with gr.Blocks() as iface:
-     gr.Markdown("# YOLOv8 Object Detection - Real-time Frame Output")
-     gr.Markdown("Upload a short video to view frames with detections immediately in a gallery. Optionally generate a confidence score graph.")
-
      with gr.Row():
-         video_input = gr.Video(label="Upload Video")
-         output_folder = gr.Textbox(label="Output Folder", value="detected_frames")
-         plot_graphs = gr.Checkbox(label="Generate Confidence Score Graph", value=False)
-
-     submit_button = gr.Button("Process Video")
-
-     status_output = gr.Text(label="Status")
-     gallery_output = gr.Gallery(label="Detected Frames and Graph", preview=True, columns=3)
-
-     submit_button.click(
-         fn=process_video,
-         inputs=[video_input, output_folder, plot_graphs],
-         outputs=[status_output, gallery_output],
-         concurrency_limit=1
      )

  if __name__ == "__main__":
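
The removed process_video above is a generator: Gradio runs a generator callback as a streaming handler and pushes each yield to the bound outputs, which is how the old interface refreshed its status text and gallery frame by frame. A minimal, self-contained sketch of that pattern, with hypothetical names, assuming a recent Gradio release:

import time
import gradio as gr

def slow_count(n):
    # Each yield replaces the current contents of the bound output,
    # so the UI updates while the loop is still running.
    for i in range(int(n)):
        time.sleep(0.5)
        yield f"step {i + 1} of {int(n)}"

with gr.Blocks() as demo:
    steps = gr.Number(value=5, label="Steps")
    status = gr.Textbox(label="Status")
    gr.Button("Run").click(slow_count, inputs=steps, outputs=status)

demo.launch()

The rewritten version below drops streaming in favor of a single return carrying the processed video, metrics JSON, logs, gallery images, chart, and map.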
 
  import gradio as gr
  import numpy as np
  import os
+ import json
+ import logging
  import matplotlib.pyplot as plt
+ from datetime import datetime
+ from collections import Counter
+ from typing import List, Dict, Any, Optional
+ from ultralytics import YOLO
+ import ultralytics
+ import time
+
+ # Set YOLO config directory
+ os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"
+
+ # Set up logging
+ logging.basicConfig(
+     filename="app.log",
+     level=logging.INFO,
+     format="%(asctime)s - %(levelname)s - %(message)s"
+ )
+
+ # Directories
+ CAPTURED_FRAMES_DIR = "captured_frames"
+ OUTPUT_DIR = "outputs"
+ os.makedirs(CAPTURED_FRAMES_DIR, exist_ok=True)
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+ os.chmod(CAPTURED_FRAMES_DIR, 0o777)
+ os.chmod(OUTPUT_DIR, 0o777)
+
+ # Global variables
+ log_entries: List[str] = []
+ detected_counts: List[int] = []
+ detected_issues: List[str] = []
+ gps_coordinates: List[List[float]] = []
+ last_metrics: Dict[str, Any] = {}
+ frame_count: int = 0
+ SAVE_IMAGE_INTERVAL = 1  # Save every frame with detections

  # Debug: Check environment
  print(f"Torch version: {torch.__version__}")
  print(f"Gradio version: {gr.__version__}")
+ print(f"Ultralytics version: {ultralytics.__version__}")
  print(f"CUDA available: {torch.cuda.is_available()}")

+ # Load custom YOLO model
  device = "cuda" if torch.cuda.is_available() else "cpu"
  print(f"Using device: {device}")
  model = YOLO('./data/best.pt').to(device)
+ if device == "cuda":
+     model.half()  # Use half-precision (FP16)
+ print(f"Model classes: {model.names}")
+
+ # Mock service functions
+ def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
+     map_path = "map_temp.png"
+     plt.figure(figsize=(4, 4))
+     plt.scatter([x[1] for x in gps_coords], [x[0] for x in gps_coords], c='blue', label='GPS Points')
+     plt.title("Issue Locations Map")
+     plt.xlabel("Longitude")
+     plt.ylabel("Latitude")
+     plt.legend()
+     plt.savefig(map_path)
+     plt.close()
+     return map_path
+
+ def send_to_salesforce(data: Dict[str, Any]) -> None:
+     pass  # Minimal mock
+
+ def update_metrics(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
+     counts = Counter([det["label"] for det in detections])
+     return {
+         "items": [{"type": k, "count": v} for k, v in counts.items()],
+         "total_detections": len(detections),
+         "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
+     }
+
+ def generate_line_chart() -> Optional[str]:
+     if not detected_counts:
+         return None
+     plt.figure(figsize=(4, 2))
+     plt.plot(detected_counts[-50:], marker='o', color='#FF8C00')
+     plt.title("Detections Over Time")
+     plt.xlabel("Frame")
+     plt.ylabel("Count")
+     plt.grid(True)
+     plt.tight_layout()
+     chart_path = "chart_temp.png"
+     plt.savefig(chart_path)
+     plt.close()
+     return chart_path
+
+ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
+     global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
+     frame_count = 0
+     detected_counts.clear()
+     detected_issues.clear()
+     gps_coordinates.clear()
+     log_entries.clear()
+     last_metrics = {}

      if video is None:
+         log_entries.append("Error: No video uploaded")
+         logging.error("No video uploaded")
+         return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None
+
+     start_time = time.time()
      cap = cv2.VideoCapture(video)
      if not cap.isOpened():
+         log_entries.append("Error: Could not open video file")
+         logging.error("Could not open video file")
+         return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None
+
+     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     expected_duration = total_frames / fps if fps > 0 else 0
+     log_entries.append(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
+     logging.info(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
+     print(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
+
+     out_width, out_height = resize_width, resize_height
+     output_path = "processed_output.mp4"
+     fourcc = cv2.VideoWriter_fourcc(*'H264')  # Use H264 codec
+     out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
+     if not out.isOpened():
+         log_entries.append("Error: Failed to initialize video writer")
+         logging.error("Failed to initialize video writer")
+         cap.release()
+         return "processed_output.mp4", json.dumps({"error": "Failed to initialize video writer"}, indent=2), "\n".join(log_entries), [], None, None
+
+     processed_frames = 0
+     all_detections = []
+     frame_times = []
+     detection_frame_count = 0
+     output_frame_count = 0
+
      while True:
          ret, frame = cap.read()
+         if not ret:
              break
          frame_count += 1
          if frame_count % frame_skip != 0:
              continue
+         processed_frames += 1
+         frame_start = time.time()
+
+         frame = cv2.resize(frame, (out_width, out_height))
+         results = model(frame, verbose=False, conf=0.5, iou=0.7)
+         annotated_frame = results[0].plot()
+
+         frame_detections = []
+         for detection in results[0].boxes:
+             cls = int(detection.cls)
+             conf = float(detection.conf)
+             box = detection.xyxy[0].cpu().numpy().astype(int).tolist()
+             label = model.names[cls]
+             frame_detections.append({"label": label, "box": box, "conf": conf})
+             log_entries.append(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
+             logging.info(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
+
+         if frame_detections:
+             detection_frame_count += 1
+             if detection_frame_count % SAVE_IMAGE_INTERVAL == 0:
+                 captured_frame_path = os.path.join(CAPTURED_FRAMES_DIR, f"detected_{frame_count}.jpg")
+                 if not cv2.imwrite(captured_frame_path, annotated_frame):
+                     log_entries.append(f"Error: Failed to save {captured_frame_path}")
+                     logging.error(f"Failed to save {captured_frame_path}")
+                 else:
+                     detected_issues.append(captured_frame_path)
+                     if len(detected_issues) > 100:
+                         detected_issues.pop(0)
+
+         # Write frame and duplicates
+         out.write(annotated_frame)
+         output_frame_count += 1
+         if frame_skip > 1:
+             for _ in range(frame_skip - 1):
+                 out.write(annotated_frame)
+                 output_frame_count += 1
+
+         detected_counts.append(len(frame_detections))
+         gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)]
+         gps_coordinates.append(gps_coord)
+         for det in frame_detections:
+             det["gps"] = gps_coord
+         all_detections.extend(frame_detections)
+
+         frame_time = (time.time() - frame_start) * 1000
+         frame_times.append(frame_time)
+         detection_summary = {
+             "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+             "frame": frame_count,
+             "cracks": sum(1 for det in frame_detections if det["label"] == "crack"),
+             "potholes": sum(1 for det in frame_detections if det["label"] == "pothole"),
+             "gps": gps_coord,
+             "processing_time_ms": frame_time
+         }
+         log_entries.append(json.dumps(detection_summary, indent=2))
+         if len(log_entries) > 50:
+             log_entries.pop(0)
+
+     # Pad remaining frames if needed
+     while output_frame_count < total_frames and annotated_frame is not None:
+         out.write(annotated_frame)
+         output_frame_count += 1
+
+     last_metrics = update_metrics(all_detections)
+     send_to_salesforce({
+         "detections": all_detections,
+         "metrics": last_metrics,
+         "timestamp": detection_summary["timestamp"] if all_detections else datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
+         "frame_count": frame_count,
+         "gps_coordinates": gps_coordinates[-1] if gps_coordinates else [0, 0]
+     })
+
+     cap.release()
+     out.release()
+
+     cap = cv2.VideoCapture(output_path)
+     output_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     output_fps = cap.get(cv2.CAP_PROP_FPS)
+     output_duration = output_frames / output_fps if output_fps > 0 else 0
      cap.release()
+
+     total_time = time.time() - start_time
+     avg_frame_time = sum(frame_times) / len(frame_times) if frame_times else 0
+     log_entries.append(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
+     log_entries.append(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
+     logging.info(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
+     logging.info(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
+     print(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
+     print(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
+
+     chart_path = generate_line_chart()
+     map_path = generate_map(gps_coordinates[-5:], all_detections)
+
+     return (
+         output_path,
+         json.dumps(last_metrics, indent=2),
+         "\n".join(log_entries[-10:]),
+         detected_issues,
+         chart_path,
+         map_path
+     )

  # Gradio interface
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
+     gr.Markdown("# Crack and Pothole Detection Dashboard")
      with gr.Row():
+         with gr.Column(scale=3):
+             video_input = gr.Video(label="Upload Video")
+             width_slider = gr.Slider(320, 640, value=320, label="Output Width", step=1)
+             height_slider = gr.Slider(240, 480, value=240, label="Output Height", step=1)
+             skip_slider = gr.Slider(1, 10, value=5, label="Frame Skip", step=1)
+             process_btn = gr.Button("Process Video", variant="primary")
+         with gr.Column(scale=1):
+             metrics_output = gr.Textbox(label="Detection Metrics", lines=5, interactive=False)
+     with gr.Row():
+         video_output = gr.Video(label="Processed Video")
+         issue_gallery = gr.Gallery(label="Detected Issues", columns=4, height="auto", object_fit="contain")
+     with gr.Row():
+         chart_output = gr.Image(label="Detection Trend")
+         map_output = gr.Image(label="Issue Locations Map")
+     with gr.Row():
+         logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
+
+     process_btn.click(
+         process_video,
+         inputs=[video_input, width_slider, height_slider, skip_slider],
+         outputs=[video_output, metrics_output, logs_output, issue_gallery, chart_output, map_output]
      )

  if __name__ == "__main__":
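
A practical note on the writer setup in the new version: cv2.VideoWriter_fourcc(*'H264') only yields a usable writer when the local OpenCV build ships an H.264 encoder; otherwise VideoWriter.isOpened() returns False, which is exactly the failure path process_video guards against. A fallback sketch, assuming the same MP4 output and a hypothetical open_writer helper:

import cv2

def open_writer(path, fps, size):
    # Try H.264 first, then fall back to the widely available mp4v
    # codec if this OpenCV build has no H.264 encoder.
    for codec in ("H264", "avc1", "mp4v"):
        fourcc = cv2.VideoWriter_fourcc(*codec)
        writer = cv2.VideoWriter(path, fourcc, fps, size)
        if writer.isOpened():
            return writer, codec
        writer.release()
    raise RuntimeError("No usable MP4 codec found")

writer, codec = open_writer("processed_output.mp4", 30.0, (320, 240))
print(f"Writing with codec: {codec}")
writer.release()

Swapping such a helper in for the direct cv2.VideoWriter call would keep the existing error handling while degrading gracefully to mp4v where H.264 is unavailable.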