nagasurendra commited on
Commit
4f2217a
·
verified ·
1 Parent(s): 4e407cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -12
app.py CHANGED
@@ -6,7 +6,7 @@ import os
6
  import json
7
  import logging
8
  import matplotlib.pyplot as plt
9
- import csv # Added for flight logs
10
  from datetime import datetime
11
  from collections import Counter
12
  from typing import List, Dict, Any, Optional
@@ -14,6 +14,8 @@ from ultralytics import YOLO
14
  import ultralytics
15
  import time
16
  import piexif
 
 
17
 
18
  # Set YOLO config directory
19
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"
@@ -43,9 +45,9 @@ detected_issues: List[str] = []
43
  gps_coordinates: List[List[float]] = []
44
  last_metrics: Dict[str, Any] = {}
45
  frame_count: int = 0
46
- SAVE_IMAGE_INTERVAL = 1 # Save every frame with detections
47
 
48
- # Detection classes (aligned with model classes, excluding 'Crocodile')
49
  DETECTION_CLASSES = ["Longitudinal", "Pothole", "Transverse"]
50
 
51
  # Debug: Check environment
@@ -59,9 +61,24 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
59
  print(f"Using device: {device}")
60
  model = YOLO('./data/best.pt').to(device)
61
  if device == "cuda":
62
- model.half() # Use half-precision (FP16)
63
  print(f"Model classes: {model.names}")
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
66
  map_path = os.path.join(OUTPUT_DIR, "map_temp.png")
67
  plt.figure(figsize=(4, 4))
@@ -110,9 +127,9 @@ def write_flight_log(frame_count: int, gps_coord: List[float], timestamp: str) -
110
  def check_image_quality(frame: np.ndarray, input_resolution: int) -> bool:
111
  height, width, _ = frame.shape
112
  frame_resolution = width * height
113
- if frame_resolution < 12_000_000: # NHAI requires 12 MP
114
  log_entries.append(f"Frame {frame_count}: Resolution {width}x{height} ({frame_resolution/1e6:.2f}MP) below 12MP, non-compliant")
115
- if frame_resolution < input_resolution: # Ensure output is not below input
116
  log_entries.append(f"Frame {frame_count}: Output resolution {width}x{height} below input resolution")
117
  return False
118
  return True
@@ -149,17 +166,24 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
149
  log_entries.clear()
150
  last_metrics = {}
151
 
 
 
 
 
 
 
 
152
  if video is None:
153
  log_entries.append("Error: No video uploaded")
154
  logging.error("No video uploaded")
155
- return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None
156
 
157
  start_time = time.time()
158
  cap = cv2.VideoCapture(video)
159
  if not cap.isOpened():
160
  log_entries.append("Error: Could not open video file")
161
  logging.error("Could not open video file")
162
- return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None
163
 
164
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
165
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
@@ -173,7 +197,7 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
173
 
174
  out_width, out_height = resize_width, resize_height
175
  output_path = os.path.join(OUTPUT_DIR, "processed_output.mp4")
176
- codecs = [('mp4v', '.mp4'), ('XVID', '.avi'), ('MJPG', '.avi')] # Prioritize mp4v
177
  out = None
178
  for codec, ext in codecs:
179
  fourcc = cv2.VideoWriter_fourcc(*codec)
@@ -192,7 +216,7 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
192
  log_entries.append("Error: All codecs failed to initialize video writer")
193
  logging.error("All codecs failed to initialize video writer")
194
  cap.release()
195
- return "processed_output.mp4", json.dumps({"error": "All codecs failed"}, indent=2), "\n".join(log_entries), [], None, None
196
 
197
  processed_frames = 0
198
  all_detections = []
@@ -344,13 +368,21 @@ def process_video(video, resize_width=4000, resize_height=3000, frame_skip=5):
344
  chart_path = generate_line_chart()
345
  map_path = generate_map(gps_coordinates[-5:], all_detections)
346
 
 
 
 
 
347
  return (
348
  output_path,
349
  json.dumps(last_metrics, indent=2),
350
  "\n".join(log_entries[-10:]),
351
  detected_issues,
352
  chart_path,
353
- map_path
 
 
 
 
354
  )
355
 
356
  # Gradio interface
@@ -373,11 +405,70 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="orange")) as iface:
373
  map_output = gr.Image(label="Issue Locations Map")
374
  with gr.Row():
375
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
376
 
377
  process_btn.click(
378
  process_video,
379
  inputs=[video_input, width_slider, height_slider, skip_slider],
380
- outputs=[video_output, metrics_output, logs_output, issue_gallery, chart_output, map_output]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381
  )
382
 
383
  if __name__ == "__main__":
 
6
  import json
7
  import logging
8
  import matplotlib.pyplot as plt
9
+ import csv
10
  from datetime import datetime
11
  from collections import Counter
12
  from typing import List, Dict, Any, Optional
 
14
  import ultralytics
15
  import time
16
  import piexif
17
+ import zipfile
18
+ import shutil
19
 
20
  # Set YOLO config directory
21
  os.environ["YOLO_CONFIG_DIR"] = "/tmp/Ultralytics"
 
45
  gps_coordinates: List[List[float]] = []
46
  last_metrics: Dict[str, Any] = {}
47
  frame_count: int = 0
48
+ SAVE_IMAGE_INTERVAL = 1
49
 
50
+ # Detection classes
51
  DETECTION_CLASSES = ["Longitudinal", "Pothole", "Transverse"]
52
 
53
  # Debug: Check environment
 
61
  print(f"Using device: {device}")
62
  model = YOLO('./data/best.pt').to(device)
63
  if device == "cuda":
64
+ model.half()
65
  print(f"Model classes: {model.names}")
66
 
67
+ def zip_directory(folder_path: str, zip_path: str) -> str:
68
+ """Zip all files in a directory."""
69
+ try:
70
+ with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
71
+ for root, _, files in os.walk(folder_path):
72
+ for file in files:
73
+ file_path = os.path.join(root, file)
74
+ arcname = os.path.relpath(file_path, folder_path)
75
+ zipf.write(file_path, arcname)
76
+ return zip_path
77
+ except Exception as e:
78
+ logging.error(f"Failed to zip {folder_path}: {str(e)}")
79
+ log_entries.append(f"Error: Failed to zip {folder_path}: {str(e)}")
80
+ return ""
81
+
82
  def generate_map(gps_coords: List[List[float]], items: List[Dict[str, Any]]) -> str:
83
  map_path = os.path.join(OUTPUT_DIR, "map_temp.png")
84
  plt.figure(figsize=(4, 4))
 
127
  def check_image_quality(frame: np.ndarray, input_resolution: int) -> bool:
128
  height, width, _ = frame.shape
129
  frame_resolution = width * height
130
+ if frame_resolution < 12_000_000:
131
  log_entries.append(f"Frame {frame_count}: Resolution {width}x{height} ({frame_resolution/1e6:.2f}MP) below 12MP, non-compliant")
132
+ if frame_resolution < input_resolution:
133
  log_entries.append(f"Frame {frame_count}: Output resolution {width}x{height} below input resolution")
134
  return False
135
  return True
 
166
  log_entries.clear()
167
  last_metrics = {}
168
 
169
+ # Clear previous outputs
170
+ for dir in [CAPTURED_FRAMES_DIR, FLIGHT_LOG_DIR, OUTPUT_DIR]:
171
+ if os.path.exists(dir):
172
+ shutil.rmtree(dir)
173
+ os.makedirs(dir, exist_ok=True)
174
+ os.chmod(dir, 0o777)
175
+
176
  if video is None:
177
  log_entries.append("Error: No video uploaded")
178
  logging.error("No video uploaded")
179
+ return "processed_output.mp4", json.dumps({"error": "No video uploaded"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
180
 
181
  start_time = time.time()
182
  cap = cv2.VideoCapture(video)
183
  if not cap.isOpened():
184
  log_entries.append("Error: Could not open video file")
185
  logging.error("Could not open video file")
186
+ return "processed_output.mp4", json.dumps({"error": "Could not open video file"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
187
 
188
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
189
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
 
197
 
198
  out_width, out_height = resize_width, resize_height
199
  output_path = os.path.join(OUTPUT_DIR, "processed_output.mp4")
200
+ codecs = [('mp4v', '.mp4'), ('XVID', '.avi'), ('MJPG', '.avi')]
201
  out = None
202
  for codec, ext in codecs:
203
  fourcc = cv2.VideoWriter_fourcc(*codec)
 
216
  log_entries.append("Error: All codecs failed to initialize video writer")
217
  logging.error("All codecs failed to initialize video writer")
218
  cap.release()
219
+ return "processed_output.mp4", json.dumps({"error": "All codecs failed"}, indent=2), "\n".join(log_entries), [], None, None, None, None, None, None
220
 
221
  processed_frames = 0
222
  all_detections = []
 
368
  chart_path = generate_line_chart()
369
  map_path = generate_map(gps_coordinates[-5:], all_detections)
370
 
371
+ # Zip captured_frames and flight_logs
372
+ images_zip = zip_directory(CAPTURED_FRAMES_DIR, os.path.join(OUTPUT_DIR, "captured_frames.zip"))
373
+ logs_zip = zip_directory(FLIGHT_LOG_DIR, os.path.join(OUTPUT_DIR, "flight_logs.zip"))
374
+
375
  return (
376
  output_path,
377
  json.dumps(last_metrics, indent=2),
378
  "\n".join(log_entries[-10:]),
379
  detected_issues,
380
  chart_path,
381
+ map_path,
382
+ submission_json_path,
383
+ images_zip,
384
+ logs_zip,
385
+ output_path # Duplicate for video download
386
  )
387
 
388
  # Gradio interface
 
405
  map_output = gr.Image(label="Issue Locations Map")
406
  with gr.Row():
407
  logs_output = gr.Textbox(label="Logs", lines=5, interactive=False)
408
+ with gr.Row():
409
+ gr.Markdown("## Download Results")
410
+ with gr.Row():
411
+ json_download = gr.File(label="Download Data Lake JSON", visible=False)
412
+ images_zip_download = gr.File(label="Download Geotagged Images (ZIP)", visible=False)
413
+ logs_zip_download = gr.File(label="Download Flight Logs (ZIP)", visible=False)
414
+ video_download = gr.File(label="Download Processed Video", visible=False)
415
+
416
+ def update_download_buttons(*outputs):
417
+ video_path, metrics, logs, issues, chart, map, json_path, images_zip, logs_zip, video_file = outputs
418
+ return (
419
+ video_path,
420
+ metrics,
421
+ logs,
422
+ issues,
423
+ chart,
424
+ map,
425
+ gr.File.update(value=json_path, visible=True),
426
+ gr.File.update(value=images_zip, visible=True),
427
+ gr.File.update(value=logs_zip, visible=True),
428
+ gr.File.update(value=video_file, visible=True)
429
+ )
430
 
431
  process_btn.click(
432
  process_video,
433
  inputs=[video_input, width_slider, height_slider, skip_slider],
434
+ outputs=[
435
+ video_output,
436
+ metrics_output,
437
+ logs_output,
438
+ issue_gallery,
439
+ chart_output,
440
+ map_output,
441
+ json_download,
442
+ images_zip_download,
443
+ logs_zip_download,
444
+ video_download
445
+ ]
446
+ ).then(
447
+ update_download_buttons,
448
+ inputs=[
449
+ video_output,
450
+ metrics_output,
451
+ logs_output,
452
+ issue_gallery,
453
+ chart_output,
454
+ map_output,
455
+ json_download,
456
+ images_zip_download,
457
+ logs_zip_download,
458
+ video_download
459
+ ],
460
+ outputs=[
461
+ video_output,
462
+ metrics_output,
463
+ logs_output,
464
+ issue_gallery,
465
+ chart_output,
466
+ map_output,
467
+ json_download,
468
+ images_zip_download,
469
+ logs_zip_download,
470
+ video_download
471
+ ]
472
  )
473
 
474
  if __name__ == "__main__":