nagasurendra commited on
Commit
486823b
·
verified ·
1 Parent(s): d552a3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -56
app.py CHANGED
@@ -2,119 +2,256 @@ import cv2
2
  import torch
3
  import gradio as gr
4
  import numpy as np
5
- from ultralytics import YOLO, __version__ as ultralytics_version
6
  import time
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
  # Debug: Check environment
9
  print(f"Torch version: {torch.__version__}")
10
  print(f"Gradio version: {gr.__version__}")
11
- print(f"Ultralytics version: {ultralytics_version}")
12
  print(f"CUDA available: {torch.cuda.is_available()}")
13
 
14
  # Load custom YOLO model
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"Using device: {device}")
17
  model = YOLO('./data/best.pt').to(device)
18
- print(f"Model classes: {model.names}") # Print classes (should include cracks, potholes)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- def process_video(video, resize_width=640, resize_height=480, frame_skip=1):
21
  if video is None:
22
- return "Error: No video uploaded"
23
-
24
- # Start timer
 
25
  start_time = time.time()
26
-
27
- # Open input video
28
  cap = cv2.VideoCapture(video)
29
  if not cap.isOpened():
30
- return "Error: Could not open video file"
31
-
32
- # Get input video properties
 
33
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
  fps = cap.get(cv2.CAP_PROP_FPS)
36
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
37
  expected_duration = total_frames / fps
 
 
38
  print(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
39
-
40
- # Set output resolution
41
  out_width, out_height = resize_width, resize_height
42
- print(f"Output resolution: {out_width}x{out_height}")
43
-
44
- # Set up video writer
45
  output_path = "processed_output.mp4"
46
- fourcc = cv2.VideoWriter_fourcc(*'mp4v') # Use 'H264' if mp4v fails
47
  out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
48
-
49
- frame_count = 0
50
  processed_frames = 0
51
-
 
52
  while True:
53
  ret, frame = cap.read()
54
  if not ret:
55
  break
56
-
57
  frame_count += 1
58
-
59
- # Skip frames if frame_skip > 1
60
  if frame_count % frame_skip != 0:
61
  continue
62
-
63
  processed_frames += 1
64
  print(f"Processing frame {frame_count}/{total_frames}")
65
-
66
- # Resize frame for faster inference
67
  frame = cv2.resize(frame, (out_width, out_height))
68
-
69
- # Run YOLO inference (detect cracks and potholes)
70
- results = model(frame, verbose=False, conf=0.5) # Confidence threshold 0.5
71
  annotated_frame = results[0].plot()
72
-
73
- # Log detections
74
  for detection in results[0].boxes:
75
  cls = int(detection.cls)
76
  conf = float(detection.conf)
77
- print(f"Frame {frame_count}: Detected {model.names[cls]} with confidence {conf:.2f}")
78
-
79
- # Write annotated frame
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  out.write(annotated_frame)
81
-
82
- # Duplicate frames if skipping to maintain duration
83
  if frame_skip > 1:
84
  for _ in range(frame_skip - 1):
85
  if frame_count + 1 <= total_frames:
86
  out.write(annotated_frame)
87
  frame_count += 1
88
-
89
- # Release resources
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  cap.release()
91
  out.release()
92
-
93
- # Verify output duration
94
  cap = cv2.VideoCapture(output_path)
95
  output_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
96
  output_fps = cap.get(cv2.CAP_PROP_FPS)
97
  output_duration = output_frames / output_fps
98
  cap.release()
99
-
 
 
100
  print(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
101
  print(f"Processing time: {time.time() - start_time:.2f} seconds")
102
-
103
- return output_path
 
 
 
 
 
 
 
 
 
 
104
 
105
  # Gradio interface
106
- iface = gr.Interface(
107
- fn=process_video,
108
- inputs=[
109
- gr.Video(label="Upload Video"),
110
- gr.Slider(minimum=320, maximum=1280, value=640, label="Output Width", step=1),
111
- gr.Slider(minimum=240, maximum=720, value=480, label="Output Height", step=1),
112
- gr.Slider(minimum=1, maximum=5, value=1, label="Frame Skip (1 = process all frames)", step=1)
113
- ],
114
- outputs=gr.Video(label="Processed Video"),
115
- title="Crack and Pothole Detection with YOLO",
116
- description="Upload a video to detect cracks and potholes. Adjust resolution and frame skip for faster processing."
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  if __name__ == "__main__":
120
  iface.launch()
 
2
  import torch
3
  import gradio as gr
4
  import numpy as np
5
+ from ultralytics import YOLO
6
  import time
7
+ import os
8
+ import json
9
+ import logging
10
+ import matplotlib.pyplot as plt
11
+ from datetime import datetime
12
+ from collections import Counter
13
+ from typing import List, Dict, Any, Optional
14
+
15
+ # Set up logging
16
+ logging.basicConfig(
17
+ filename="app.log",
18
+ level=logging.INFO,
19
+ format="%(asctime)s - %(levelname)s - %(message)s"
20
+ )
21
+
22
+ # Directories
23
+ CAPTURED_FRAMES_DIR = "captured_frames"
24
+ OUTPUT_DIR = "outputs"
25
+ os.makedirs(CAPTURED_FRAMES_DIR, exist_ok=True)
26
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
27
+ os.chmod(CAPTURED_FRAMES_DIR, 0o777)
28
+ os.chmod(OUTPUT_DIR, 0o777)
29
+
30
+ # Global variables
31
+ log_entries: List[str] = []
32
+ detected_counts: List[int] = []
33
+ detected_issues: List[str] = []
34
+ gps_coordinates: List[List[float]] = []
35
+ last_metrics: Dict[str, Any] = {}
36
+ frame_count: int = 0
37
 
38
  # Debug: Check environment
39
  print(f"Torch version: {torch.__version__}")
40
  print(f"Gradio version: {gr.__version__}")
41
+ print(f"Ultralytics version: {YOLO.__version__}")
42
  print(f"CUDA available: {torch.cuda.is_available()}")
43
 
44
  # Load custom YOLO model
45
  device = "cuda" if torch.cuda.is_available() else "cpu"
46
  print(f"Using device: {device}")
47
  model = YOLO('./data/best.pt').to(device)
48
+ if device == "cuda":
49
+ model.half() # Use half-precision (FP16)
50
+ print(f"Model classes: {model.names}")
51
+
52
+ # Mock service functions (replace with actual implementations if available)
53
+ def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
54
+ """Mock map generation: returns a placeholder image path."""
55
+ map_path = "map_temp.png"
56
+ plt.figure(figsize=(4, 4))
57
+ plt.scatter([x[1] for x in gps_coords], [x[0] for x in gps_coords], c='blue', label='GPS Points')
58
+ plt.title("Mock Issue Locations Map")
59
+ plt.xlabel("Longitude")
60
+ plt.ylabel("Latitude")
61
+ plt.legend()
62
+ plt.savefig(map_path)
63
+ plt.close()
64
+ return map_path
65
+
66
+ def send_to_salesforce(data: Dict[str, Any]) -> None:
67
+ """Mock Salesforce dispatch: logs data."""
68
+ logging.info(f"Mock Salesforce dispatch: {json.dumps(data, indent=2)}")
69
+
70
+ def update_metrics(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
71
+ """Compute detection metrics."""
72
+ counts = Counter([det["label"] for det in detections])
73
+ return {
74
+ "items": [{"type": k, "count": v} for k, v in counts.items()],
75
+ "total_detections": len(detections),
76
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
77
+ }
78
+
79
+ def generate_line_chart() -> Optional[str]:
80
+ """Generate detection trend chart."""
81
+ if not detected_counts:
82
+ return None
83
+ plt.figure(figsize=(4, 2))
84
+ plt.plot(detected_counts[-50:], marker='o', color='#FF8C00')
85
+ plt.title("Detections Over Time")
86
+ plt.xlabel("Frame")
87
+ plt.ylabel("Count")
88
+ plt.grid(True)
89
+ plt.tight_layout()
90
+ chart_path = "chart_temp.png"
91
+ plt.savefig(chart_path)
92
+ plt.close()
93
+ return chart_path
94
+
95
+ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
96
+ global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
97
+ frame_count = 0
98
+ detected_counts.clear()
99
+ detected_issues.clear()
100
+ gps_coordinates.clear()
101
+ log_entries.clear()
102
+ last_metrics = {}
103
 
 
104
  if video is None:
105
+ log_entries.append("Error: No video uploaded")
106
+ logging.error("No video uploaded")
107
+ return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None
108
+
109
  start_time = time.time()
 
 
110
  cap = cv2.VideoCapture(video)
111
  if not cap.isOpened():
112
+ log_entries.append("Error: Could not open video file")
113
+ logging.error("Could not open video file")
114
+ return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None
115
+
116
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
117
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
118
  fps = cap.get(cv2.CAP_PROP_FPS)
119
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
120
  expected_duration = total_frames / fps
121
+ log_entries.append(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
122
+ logging.info(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
123
  print(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
124
+
 
125
  out_width, out_height = resize_width, resize_height
 
 
 
126
  output_path = "processed_output.mp4"
127
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
128
  out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
129
+
 
130
  processed_frames = 0
131
+ all_detections = []
132
+
133
  while True:
134
  ret, frame = cap.read()
135
  if not ret:
136
  break
 
137
  frame_count += 1
 
 
138
  if frame_count % frame_skip != 0:
139
  continue
 
140
  processed_frames += 1
141
  print(f"Processing frame {frame_count}/{total_frames}")
142
+
 
143
  frame = cv2.resize(frame, (out_width, out_height))
144
+ results = model(frame, verbose=False, conf=0.5, iou=0.7)
 
 
145
  annotated_frame = results[0].plot()
146
+
147
+ frame_detections = []
148
  for detection in results[0].boxes:
149
  cls = int(detection.cls)
150
  conf = float(detection.conf)
151
+ box = detection.xyxy[0].cpu().numpy().astype(int).tolist() # [x_min, y_min, x_max, y_max]
152
+ label = model.names[cls]
153
+ frame_detections.append({"label": label, "box": box, "conf": conf})
154
+ log_entries.append(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
155
+ logging.info(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
156
+
157
+ if frame_detections:
158
+ captured_frame_path = os.path.join(CAPTURED_FRAMES_DIR, f"detected_{frame_count}.jpg")
159
+ cv2.imwrite(captured_frame_path, annotated_frame)
160
+ detected_issues.append(captured_frame_path)
161
+ if len(detected_issues) > 100:
162
+ detected_issues.pop(0)
163
+
164
+ frame_path = os.path.join(OUTPUT_DIR, f"frame_{frame_count:04d}.jpg")
165
+ cv2.imwrite(frame_path, annotated_frame)
166
+
167
+ gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)] # Simulated GPS
168
+ gps_coordinates.append(gps_coord)
169
+ for det in frame_detections:
170
+ det["gps"] = gps_coord
171
+ all_detections.extend(frame_detections)
172
+
173
  out.write(annotated_frame)
 
 
174
  if frame_skip > 1:
175
  for _ in range(frame_skip - 1):
176
  if frame_count + 1 <= total_frames:
177
  out.write(annotated_frame)
178
  frame_count += 1
179
+
180
+ detected_counts.append(len(frame_detections))
181
+ last_metrics = update_metrics(all_detections)
182
+ detection_summary = {
183
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
184
+ "frame": frame_count,
185
+ "cracks": sum(1 for det in frame_detections if det["label"] == "crack"),
186
+ "potholes": sum(1 for det in frame_detections if det["label"] == "pothole"),
187
+ "gps": gps_coord,
188
+ "processing_time_ms": (time.time() - start_time) * 1000 / processed_frames if processed_frames else 0
189
+ }
190
+ log_entries.append(json.dumps(detection_summary, indent=2))
191
+ logging.info(json.dumps(detection_summary, indent=2))
192
+ if len(log_entries) > 100:
193
+ log_entries.pop(0)
194
+
195
+ send_to_salesforce({
196
+ "detections": frame_detections,
197
+ "metrics": last_metrics,
198
+ "timestamp": detection_summary["timestamp"],
199
+ "frame_count": frame_count,
200
+ "gps_coordinates": gps_coord
201
+ })
202
+
203
  cap.release()
204
  out.release()
205
+
 
206
  cap = cv2.VideoCapture(output_path)
207
  output_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
208
  output_fps = cap.get(cv2.CAP_PROP_FPS)
209
  output_duration = output_frames / output_fps
210
  cap.release()
211
+
212
+ log_entries.append(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
213
+ logging.info(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
214
  print(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
215
  print(f"Processing time: {time.time() - start_time:.2f} seconds")
216
+
217
+ chart_path = generate_line_chart()
218
+ map_path = generate_map(gps_coordinates[-5:], all_detections)
219
+
220
+ return (
221
+ output_path,
222
+ json.dumps(last_metrics, indent=2),
223
+ "\n".join(log_entries[-10:]),
224
+ detected_issues,
225
+ chart_path,
226
+ map_path
227
+ )
228
 
229
  # Gradio interface
230
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
231
+ gr.Markdown("# Crack and Pothole Detection Dashboard")
232
+ with gr.Row():
233
+ with gr.Column(scale=3):
234
+ video_input = gr.Video(label="Upload Video")
235
+ width_slider = gr.Slider(320, 1280, value=320, label="Output Width", step=1)
236
+ height_slider = gr.Slider(240, 720, value=240, label="Output Height", step=1)
237
+ skip_slider = gr.Slider(1, 10, value=5, label="Frame Skip", step=1)
238
+ process_btn = gr.Button("Process Video", variant="primary")
239
+ with gr.Column(scale=1):
240
+ metrics_output = gr.Textbox(label="Detection Metrics", lines=10, interactive=False)
241
+ with gr.Row():
242
+ video_output = gr.Video(label="Processed Video")
243
+ issue_gallery = gr.Gallery(label="Detected Issues", columns=4, height="auto")
244
+ with gr.Row():
245
+ chart_output = gr.Image(label="Detection Trend")
246
+ map_output = gr.Image(label="Issue Locations Map")
247
+ with gr.Row():
248
+ logs_output = gr.Textbox(label="Logs", lines=8, interactive=False)
249
+
250
+ process_btn.click(
251
+ process_video,
252
+ inputs=[video_input, width_slider, height_slider, skip_slider],
253
+ outputs=[video_output, metrics_output, logs_output, issue_gallery, chart_output, map_output]
254
+ )
255
 
256
  if __name__ == "__main__":
257
  iface.launch()