nagasurendra commited on
Commit
db24399
·
verified ·
1 Parent(s): c16972a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -113
app.py CHANGED
@@ -6,6 +6,7 @@ import os
6
  import json
7
  import logging
8
  import matplotlib.pyplot as plt
 
9
  from datetime import datetime
10
  from collections import Counter
11
  from typing import List, Dict, Any, Optional
@@ -25,10 +26,13 @@ logging.basicConfig(
25
 
26
  # Directories
27
  CAPTURED_FRAMES_DIR = "captured_frames"
 
28
  OUTPUT_DIR = "outputs"
29
  os.makedirs(CAPTURED_FRAMES_DIR, exist_ok=True)
 
30
  os.makedirs(OUTPUT_DIR, exist_ok=True)
31
  os.chmod(CAPTURED_FRAMES_DIR, 0o777)
 
32
  os.chmod(OUTPUT_DIR, 0o777)
33
 
34
  # Global variables
@@ -38,23 +42,16 @@ detected_issues: List[str] = []
38
  gps_coordinates: List[List[float]] = []
39
  last_metrics: Dict[str, Any] = {}
40
  frame_count: int = 0
41
- SAVE_IMAGE_INTERVAL = 1 # Save every frame with detections
42
 
43
- # Debug: Check environment
44
- print(f"Torch version: {torch.__version__}")
45
- print(f"Gradio version: {gr.__version__}")
46
- print(f"Ultralytics version: {ultralytics.__version__}")
47
- print(f"CUDA available: {torch.cuda.is_available()}")
48
-
49
- # Load custom YOLO model
50
  device = "cuda" if torch.cuda.is_available() else "cpu"
51
- print(f"Using device: {device}")
52
  model = YOLO('./data/best.pt').to(device)
53
  if device == "cuda":
54
- model.half() # Use half-precision (FP16)
55
- print(f"Model classes: {model.names}")
56
 
57
- # Mock service functions
58
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
59
  map_path = "map_temp.png"
60
  plt.figure(figsize=(4, 4))
@@ -68,7 +65,7 @@ def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) ->
68
  return map_path
69
 
70
  def send_to_salesforce(data: Dict[str, Any]) -> None:
71
- pass # Minimal mock
72
 
73
  def update_metrics(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
74
  counts = Counter([det["label"] for det in detections])
@@ -81,6 +78,7 @@ def update_metrics(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
81
  def generate_line_chart() -> Optional[str]:
82
  if not detected_counts:
83
  return None
 
84
  plt.figure(figsize=(4, 2))
85
  plt.plot(detected_counts[-50:], marker='o', color='#FF8C00')
86
  plt.title("Detections Over Time")
@@ -88,11 +86,20 @@ def generate_line_chart() -> Optional[str]:
88
  plt.ylabel("Count")
89
  plt.grid(True)
90
  plt.tight_layout()
91
- chart_path = "chart_temp.png"
92
  plt.savefig(chart_path)
93
  plt.close()
94
  return chart_path
95
 
 
 
 
 
 
 
 
 
 
 
96
  def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
97
  global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
98
  frame_count = 0
@@ -102,56 +109,30 @@ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
102
  log_entries.clear()
103
  last_metrics = {}
104
 
 
 
 
 
105
  if video is None:
106
  log_entries.append("Error: No video uploaded")
107
- logging.error("No video uploaded")
108
- return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None
109
 
110
- start_time = time.time()
111
  cap = cv2.VideoCapture(video)
112
  if not cap.isOpened():
113
  log_entries.append("Error: Could not open video file")
114
- logging.error("Could not open video file")
115
- return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None
116
 
117
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
118
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
119
  fps = cap.get(cv2.CAP_PROP_FPS)
120
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
121
- expected_duration = total_frames / fps if fps > 0 else 0
122
- log_entries.append(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
123
- logging.info(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
124
- print(f"Input video: {frame_width}x{frame_height}, {fps} FPS, {total_frames} frames, {expected_duration:.2f} seconds")
125
-
126
- out_width, out_height = resize_width, resize_height
127
  output_path = "processed_output.mp4"
128
- # Try codecs in order of preference
129
- codecs = [('mp4v', '.mp4'), ('MJPG', '.avi'), ('XVID', '.avi')]
130
- out = None
131
- for codec, ext in codecs:
132
- fourcc = cv2.VideoWriter_fourcc(*codec)
133
- output_path = f"processed_output{ext}"
134
- out = cv2.VideoWriter(output_path, fourcc, fps, (out_width, out_height))
135
- if out.isOpened():
136
- log_entries.append(f"Using codec: {codec}, output: {output_path}")
137
- logging.info(f"Using codec: {codec}, output: {output_path}")
138
- break
139
- else:
140
- log_entries.append(f"Failed to initialize codec: {codec}")
141
- logging.warning(f"Failed to initialize codec: {codec}")
142
-
143
- if not out or not out.isOpened():
144
- log_entries.append("Error: All codecs failed to initialize video writer")
145
- logging.error("All codecs failed to initialize video writer")
146
- cap.release()
147
- return "processed_output.mp4", json.dumps({"error": "All codecs failed"}, indent=2), "\n".join(log_entries), [], None, None
148
 
149
- processed_frames = 0
150
  all_detections = []
151
  frame_times = []
152
  detection_frame_count = 0
153
- output_frame_count = 0
154
- last_annotated_frame = None
155
 
156
  while True:
157
  ret, frame = cap.read()
@@ -160,77 +141,47 @@ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
160
  frame_count += 1
161
  if frame_count % frame_skip != 0:
162
  continue
163
- processed_frames += 1
164
- frame_start = time.time()
165
 
166
- frame = cv2.resize(frame, (out_width, out_height))
167
  results = model(frame, verbose=False, conf=0.5, iou=0.7)
168
  annotated_frame = results[0].plot()
169
 
 
 
 
 
170
  frame_detections = []
171
  for detection in results[0].boxes:
172
  cls = int(detection.cls)
173
  conf = float(detection.conf)
174
  box = detection.xyxy[0].cpu().numpy().astype(int).tolist()
175
  label = model.names[cls]
176
- if label != 'Crocodile': # Ignore irrelevant class
177
- frame_detections.append({"label": label, "box": box, "conf": conf})
178
- log_entries.append(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
179
- logging.info(f"Frame {frame_count}: Detected {label} with confidence {conf:.2f}")
180
 
181
  if frame_detections:
182
  detection_frame_count += 1
183
  if detection_frame_count % SAVE_IMAGE_INTERVAL == 0:
184
- captured_frame_path = os.path.join(CAPTURED_FRAMES_DIR, f"detected_{frame_count}.jpg")
185
- if not cv2.imwrite(captured_frame_path, annotated_frame):
186
- log_entries.append(f"Error: Failed to save {captured_frame_path}")
187
- logging.error(f"Failed to save {captured_frame_path}")
188
- else:
189
- detected_issues.append(captured_frame_path)
190
- if len(detected_issues) > 100:
191
- detected_issues.pop(0)
192
-
193
- # Write frame and duplicates
194
- out.write(annotated_frame)
195
- output_frame_count += 1
196
- last_annotated_frame = annotated_frame
197
- if frame_skip > 1:
198
- for _ in range(frame_skip - 1):
199
- out.write(annotated_frame)
200
- output_frame_count += 1
201
 
202
- detected_counts.append(len(frame_detections))
203
  gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)]
204
  gps_coordinates.append(gps_coord)
205
  for det in frame_detections:
206
  det["gps"] = gps_coord
207
  all_detections.extend(frame_detections)
208
-
209
- frame_time = (time.time() - frame_start) * 1000
210
  frame_times.append(frame_time)
211
- detection_summary = {
212
- "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
213
- "frame": frame_count,
214
- "longitudinal": sum(1 for det in frame_detections if det["label"] == "Longitudinal"),
215
- "pothole": sum(1 for det in frame_detections if det["label"] == "Pothole"),
216
- "transverse": sum(1 for det in frame_detections if det["label"] == "Transverse"),
217
- "gps": gps_coord,
218
- "processing_time_ms": frame_time
219
- }
220
- log_entries.append(json.dumps(detection_summary, indent=2))
221
- if len(log_entries) > 50:
222
- log_entries.pop(0)
223
-
224
- # Pad remaining frames
225
- while output_frame_count < total_frames and last_annotated_frame is not None:
226
- out.write(last_annotated_frame)
227
- output_frame_count += 1
228
 
229
  last_metrics = update_metrics(all_detections)
230
  send_to_salesforce({
231
  "detections": all_detections,
232
  "metrics": last_metrics,
233
- "timestamp": detection_summary["timestamp"] if all_detections else datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
234
  "frame_count": frame_count,
235
  "gps_coordinates": gps_coordinates[-1] if gps_coordinates else [0, 0]
236
  })
@@ -238,23 +189,10 @@ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
238
  cap.release()
239
  out.release()
240
 
241
- cap = cv2.VideoCapture(output_path)
242
- output_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
243
- output_fps = cap.get(cv2.CAP_PROP_FPS)
244
- output_duration = output_frames / output_fps if output_fps > 0 else 0
245
- cap.release()
246
-
247
- total_time = time.time() - start_time
248
- avg_frame_time = sum(frame_times) / len(frame_times) if frame_times else 0
249
- log_entries.append(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
250
- log_entries.append(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
251
- logging.info(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
252
- logging.info(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
253
- print(f"Output video: {output_frames} frames, {output_fps} FPS, {output_duration:.2f} seconds")
254
- print(f"Total processing time: {total_time:.2f} seconds, Avg frame time: {avg_frame_time:.2f} ms, Detection frames: {detection_frame_count}, Output frames: {output_frame_count}")
255
-
256
  chart_path = generate_line_chart()
257
  map_path = generate_map(gps_coordinates[-5:], all_detections)
 
 
258
 
259
  return (
260
  output_path,
@@ -262,12 +200,15 @@ def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
262
  "\n".join(log_entries[-10:]),
263
  detected_issues,
264
  chart_path,
265
- map_path
 
 
266
  )
267
 
268
- # Gradio interface
269
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
270
- gr.Markdown("# Road Defect Detection Dashboard")
 
271
  with gr.Row():
272
  with gr.Column(scale=3):
273
  video_input = gr.Video(label="Upload Video")
@@ -277,20 +218,36 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
277
  process_btn = gr.Button("Process Video", variant="primary")
278
  with gr.Column(scale=1):
279
  metrics_output = gr.Textbox(label="Detection Metrics", lines=5, interactive=False)
 
280
  with gr.Row():
281
  video_output = gr.Video(label="Processed Video")
282
  issue_gallery = gr.Gallery(label="Detected Issues", columns=4, height="auto", object_fit="contain")
 
283
  with gr.Row():
284
  chart_output = gr.Image(label="Detection Trend")
285
  map_output = gr.Image(label="Issue Locations Map")
 
286
  with gr.Row():
287
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
288
 
 
 
 
 
289
  process_btn.click(
290
  process_video,
291
  inputs=[video_input, width_slider, height_slider, skip_slider],
292
- outputs=[video_output, metrics_output, logs_output, issue_gallery, chart_output, map_output]
 
 
 
 
 
 
 
 
 
293
  )
294
 
295
  if __name__ == "__main__":
296
- iface.launch()
 
6
  import json
7
  import logging
8
  import matplotlib.pyplot as plt
9
+ import zipfile
10
  from datetime import datetime
11
  from collections import Counter
12
  from typing import List, Dict, Any, Optional
 
26
 
27
  # Directories
28
  CAPTURED_FRAMES_DIR = "captured_frames"
29
+ ORIGINAL_FRAMES_DIR = "original_frames"
30
  OUTPUT_DIR = "outputs"
31
  os.makedirs(CAPTURED_FRAMES_DIR, exist_ok=True)
32
+ os.makedirs(ORIGINAL_FRAMES_DIR, exist_ok=True)
33
  os.makedirs(OUTPUT_DIR, exist_ok=True)
34
  os.chmod(CAPTURED_FRAMES_DIR, 0o777)
35
+ os.chmod(ORIGINAL_FRAMES_DIR, 0o777)
36
  os.chmod(OUTPUT_DIR, 0o777)
37
 
38
  # Global variables
 
42
  gps_coordinates: List[List[float]] = []
43
  last_metrics: Dict[str, Any] = {}
44
  frame_count: int = 0
45
+ SAVE_IMAGE_INTERVAL = 1
46
 
47
+ # Load model
 
 
 
 
 
 
48
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
49
  model = YOLO('./data/best.pt').to(device)
50
  if device == "cuda":
51
+ model.half()
52
+ print(f"Using {device}, model classes: {model.names}")
53
 
54
+ # Helper functions
55
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
56
  map_path = "map_temp.png"
57
  plt.figure(figsize=(4, 4))
 
65
  return map_path
66
 
67
  def send_to_salesforce(data: Dict[str, Any]) -> None:
68
+ pass # Placeholder
69
 
70
  def update_metrics(detections: List[Dict[str, Any]]) -> Dict[str, Any]:
71
  counts = Counter([det["label"] for det in detections])
 
78
  def generate_line_chart() -> Optional[str]:
79
  if not detected_counts:
80
  return None
81
+ chart_path = "chart_temp.png"
82
  plt.figure(figsize=(4, 2))
83
  plt.plot(detected_counts[-50:], marker='o', color='#FF8C00')
84
  plt.title("Detections Over Time")
 
86
  plt.ylabel("Count")
87
  plt.grid(True)
88
  plt.tight_layout()
 
89
  plt.savefig(chart_path)
90
  plt.close()
91
  return chart_path
92
 
93
+ def create_zip_from_directory(dir_path: str, zip_filename: str) -> str:
94
+ zip_path = os.path.join(OUTPUT_DIR, zip_filename)
95
+ with zipfile.ZipFile(zip_path, 'w') as zipf:
96
+ for root, _, files in os.walk(dir_path):
97
+ for file in files:
98
+ full_path = os.path.join(root, file)
99
+ zipf.write(full_path, arcname=file)
100
+ return zip_path
101
+
102
+ # Main function
103
  def process_video(video, resize_width=320, resize_height=240, frame_skip=5):
104
  global frame_count, last_metrics, detected_counts, detected_issues, gps_coordinates, log_entries
105
  frame_count = 0
 
109
  log_entries.clear()
110
  last_metrics = {}
111
 
112
+ for dir_ in [CAPTURED_FRAMES_DIR, ORIGINAL_FRAMES_DIR]:
113
+ for file in os.listdir(dir_):
114
+ os.remove(os.path.join(dir_, file))
115
+
116
  if video is None:
117
  log_entries.append("Error: No video uploaded")
118
+ return None, json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None, None, None
 
119
 
 
120
  cap = cv2.VideoCapture(video)
121
  if not cap.isOpened():
122
  log_entries.append("Error: Could not open video file")
123
+ return None, json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None, None, None
 
124
 
125
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
126
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
127
  fps = cap.get(cv2.CAP_PROP_FPS)
128
  total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
 
 
 
 
 
129
  output_path = "processed_output.mp4"
130
+ out = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (resize_width, resize_height))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
 
 
132
  all_detections = []
133
  frame_times = []
134
  detection_frame_count = 0
135
+ start_time = time.time()
 
136
 
137
  while True:
138
  ret, frame = cap.read()
 
141
  frame_count += 1
142
  if frame_count % frame_skip != 0:
143
  continue
 
 
144
 
145
+ frame = cv2.resize(frame, (resize_width, resize_height))
146
  results = model(frame, verbose=False, conf=0.5, iou=0.7)
147
  annotated_frame = results[0].plot()
148
 
149
+ # Save original frame
150
+ original_path = os.path.join(ORIGINAL_FRAMES_DIR, f"frame_{frame_count}.jpg")
151
+ cv2.imwrite(original_path, frame)
152
+
153
  frame_detections = []
154
  for detection in results[0].boxes:
155
  cls = int(detection.cls)
156
  conf = float(detection.conf)
157
  box = detection.xyxy[0].cpu().numpy().astype(int).tolist()
158
  label = model.names[cls]
159
+ frame_detections.append({"label": label, "box": box, "conf": conf})
 
 
 
160
 
161
  if frame_detections:
162
  detection_frame_count += 1
163
  if detection_frame_count % SAVE_IMAGE_INTERVAL == 0:
164
+ captured_path = os.path.join(CAPTURED_FRAMES_DIR, f"frame_{frame_count}.jpg")
165
+ cv2.imwrite(captured_path, annotated_frame)
166
+ detected_issues.append(captured_path)
167
+ if len(detected_issues) > 100:
168
+ detected_issues.pop(0)
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
+ out.write(annotated_frame)
171
  gps_coord = [17.385044 + (frame_count * 0.0001), 78.486671 + (frame_count * 0.0001)]
172
  gps_coordinates.append(gps_coord)
173
  for det in frame_detections:
174
  det["gps"] = gps_coord
175
  all_detections.extend(frame_detections)
176
+ detected_counts.append(len(frame_detections))
177
+ frame_time = (time.time() - start_time) * 1000
178
  frame_times.append(frame_time)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
 
180
  last_metrics = update_metrics(all_detections)
181
  send_to_salesforce({
182
  "detections": all_detections,
183
  "metrics": last_metrics,
184
+ "timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
185
  "frame_count": frame_count,
186
  "gps_coordinates": gps_coordinates[-1] if gps_coordinates else [0, 0]
187
  })
 
189
  cap.release()
190
  out.release()
191
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
192
  chart_path = generate_line_chart()
193
  map_path = generate_map(gps_coordinates[-5:], all_detections)
194
+ originals_zip = create_zip_from_directory(ORIGINAL_FRAMES_DIR, "original_images.zip")
195
+ annotated_zip = create_zip_from_directory(CAPTURED_FRAMES_DIR, "annotated_images.zip")
196
 
197
  return (
198
  output_path,
 
200
  "\n".join(log_entries[-10:]),
201
  detected_issues,
202
  chart_path,
203
+ map_path,
204
+ originals_zip,
205
+ annotated_zip
206
  )
207
 
208
+ # Gradio UI
209
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
210
+ gr.Markdown("# Crack and Pothole Detection Dashboard")
211
+
212
  with gr.Row():
213
  with gr.Column(scale=3):
214
  video_input = gr.Video(label="Upload Video")
 
218
  process_btn = gr.Button("Process Video", variant="primary")
219
  with gr.Column(scale=1):
220
  metrics_output = gr.Textbox(label="Detection Metrics", lines=5, interactive=False)
221
+
222
  with gr.Row():
223
  video_output = gr.Video(label="Processed Video")
224
  issue_gallery = gr.Gallery(label="Detected Issues", columns=4, height="auto", object_fit="contain")
225
+
226
  with gr.Row():
227
  chart_output = gr.Image(label="Detection Trend")
228
  map_output = gr.Image(label="Issue Locations Map")
229
+
230
  with gr.Row():
231
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
232
 
233
+ with gr.Row():
234
+ originals_zip_out = gr.File(label="Download Original Images (ZIP)")
235
+ annotated_zip_out = gr.File(label="Download Annotated Images (ZIP)")
236
+
237
  process_btn.click(
238
  process_video,
239
  inputs=[video_input, width_slider, height_slider, skip_slider],
240
+ outputs=[
241
+ video_output,
242
+ metrics_output,
243
+ logs_output,
244
+ issue_gallery,
245
+ chart_output,
246
+ map_output,
247
+ originals_zip_out,
248
+ annotated_zip_out
249
+ ]
250
  )
251
 
252
  if __name__ == "__main__":
253
+ iface.launch()