simran0608 committed
Commit 7317aac · verified · 1 Parent(s): 41fadd8

Update streamlit_app.py

Files changed (1)
  1. streamlit_app.py +66 -217
streamlit_app.py CHANGED
@@ -1,11 +1,13 @@
  import asyncio
  import sys
  
  if sys.platform.startswith('linux') and sys.version_info >= (3, 8):
      try:
          asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
      except Exception:
          pass
  import streamlit as st
  from PIL import Image
  import numpy as np
@@ -15,7 +17,9 @@ import tempfile
  import os
  from ultralytics import YOLO
  import cv2 as cv
- import pandas as pd
  
  model_path="best.pt"
  
@@ -31,19 +35,17 @@ st.set_page_config(
  st.sidebar.title("🚗 Driver Distraction System")
  st.sidebar.write("Choose an option below:")
  
- # Sidebar navigation
  page = st.sidebar.radio("Select Feature", [
-     "Distraction System",
-     "Real-time Drowsiness Detection",
-     "Video Drowsiness Detection"
  ])
  
  # --- Class Labels (for YOLO model) ---
  class_names = ['drinking', 'hair and makeup', 'operating the radio', 'reaching behind',
                 'safe driving', 'talking on the phone', 'talking to passenger', 'texting']
-
- # Sidebar Class Name Display
- st.sidebar.subheader("Class Names")
  for idx, class_name in enumerate(class_names):
      st.sidebar.write(f"{idx}: {class_name}")
  
@@ -86,233 +88,80 @@ if page == "Distraction System":
          else:
              st.warning("No distractions detected.")
  
-     else: # Video processing
-         uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
-
-         if uploaded_video is not None:
-             tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
-             tfile.write(uploaded_video.read())
-             temp_input_path = tfile.name
-             temp_output_path = tempfile.mktemp(suffix="_distraction_detected.mp4")
-
-             st.subheader("Video Information")
-             cap = cv.VideoCapture(temp_input_path)
-             fps = cap.get(cv.CAP_PROP_FPS)
-             width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
-             height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
-             total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
-             duration = total_frames / fps if fps > 0 else 0
-             cap.release()
-
-             col1, col2 = st.columns(2)
-             with col1:
-                 st.metric("Duration", f"{duration:.2f} seconds")
-                 st.metric("Original FPS", f"{fps:.2f}")
-             with col2:
-                 st.metric("Resolution", f"{width}x{height}")
-                 st.metric("Total Frames", total_frames)
-
-             st.subheader("Original Video Preview")
-             st.video(uploaded_video)
-
-             if st.button("Process Video for Distraction Detection"):
-                 TARGET_PROCESSING_FPS = 10
-                 # --- NEW: Hyperparameter for the temporal smoothing logic ---
-                 PERSISTENCE_CONFIDENCE_THRESHOLD = 0.40 # Stick with old class if found with >= 40% confidence
-
-                 st.info(f"🚀 For faster results, video will be processed at ~{TARGET_PROCESSING_FPS} FPS.")
-                 st.info(f"🧠 Applying temporal smoothing to reduce status flickering (Persistence Threshold: {PERSISTENCE_CONFIDENCE_THRESHOLD*100:.0f}%).")
-
-                 progress_bar = st.progress(0, text="Starting video processing...")
-
-                 with st.spinner(f"Processing video... This may take a while."):
-                     model = YOLO(model_path)
-                     cap = cv.VideoCapture(temp_input_path)
-
-                     fourcc = cv.VideoWriter_fourcc(*'mp4v')
-                     out = cv.VideoWriter(temp_output_path, fourcc, fps, (width, height))
-
-                     frame_skip_interval = max(1, round(fps / TARGET_PROCESSING_FPS))
-
-                     frame_count = 0
-                     last_best_box_coords = None
-                     last_best_box_label = ""
-                     last_status_text = "Status: Initializing..."
-                     last_status_color = (128, 128, 128)
-                     # --- NEW: State variable to store the last confirmed class ---
-                     last_confirmed_class_name = 'safe driving'
-
-                     while cap.isOpened():
-                         ret, frame = cap.read()
-                         if not ret:
-                             break
-
-                         frame_count += 1
-                         progress = int((frame_count / total_frames) * 100) if total_frames > 0 else 0
-                         progress_bar.progress(progress, text=f"Analyzing frame {frame_count}/{total_frames}")
-
-                         annotated_frame = frame.copy()
-
-                         if frame_count % frame_skip_interval == 0:
-                             results = model(annotated_frame)
-                             result = results[0]
-
-                             last_best_box_coords = None # Reset box for this processing cycle
-
-                             if len(result.boxes) > 0:
-                                 boxes = result.boxes
-                                 class_names_dict = result.names
-                                 confidences = boxes.conf.cpu().numpy()
-                                 classes = boxes.cls.cpu().numpy()
-
-                                 # --- NEW STABILITY LOGIC ---
-                                 final_box_to_use = None
-
-                                 # 1. Check if the last known class exists with reasonable confidence
-                                 for i in range(len(boxes)):
-                                     current_class_name = class_names_dict[int(classes[i])]
-                                     if current_class_name == last_confirmed_class_name and confidences[i] >= PERSISTENCE_CONFIDENCE_THRESHOLD:
-                                         final_box_to_use = boxes[i]
-                                         break
-
-                                 # 2. If not, fall back to the highest confidence detection in the current frame
-                                 if final_box_to_use is None:
-                                     max_conf_idx = confidences.argmax()
-                                     final_box_to_use = boxes[max_conf_idx]
-                                 # --- END OF NEW LOGIC ---
-
-                                 # Now, process the determined "final_box_to_use"
-                                 x1, y1, x2, y2 = final_box_to_use.xyxy[0].cpu().numpy()
-                                 confidence = final_box_to_use.conf[0].cpu().numpy()
-                                 class_id = int(final_box_to_use.cls[0].cpu().numpy())
-                                 class_name = class_names_dict[class_id]
-
-                                 # Update the state for the next frames
-                                 last_confirmed_class_name = class_name
-                                 last_best_box_coords = (int(x1), int(y1), int(x2), int(y2))
-                                 last_best_box_label = f"{class_name}: {confidence:.2f}"
-
-                                 if class_name != 'safe driving':
-                                     last_status_text = f"Status: {class_name.replace('_', ' ').title()}"
-                                     last_status_color = (0, 0, 255)
-                                 else:
-                                     last_status_text = "Status: Safe Driving"
-                                     last_status_color = (0, 128, 0)
-                             else:
-                                 # No detections, reset to safe driving
-                                 last_confirmed_class_name = 'safe driving'
-                                 last_status_text = "Status: Safe Driving"
-                                 last_status_color = (0, 128, 0)
-
-                         # Draw annotations on EVERY frame using the last known data
-                         if last_best_box_coords:
-                             cv.rectangle(annotated_frame, (last_best_box_coords[0], last_best_box_coords[1]),
-                                          (last_best_box_coords[2], last_best_box_coords[3]), (0, 255, 0), 2)
-                             cv.putText(annotated_frame, last_best_box_label,
-                                        (last_best_box_coords[0], last_best_box_coords[1] - 10),
-                                        cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
-
-                         # Draw status text
-                         font_scale, font_thickness = 1.0, 2
-                         (text_w, text_h), _ = cv.getTextSize(last_status_text, cv.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
-                         padding = 10
-                         rect_start = (padding, padding)
-                         rect_end = (padding + text_w + padding, padding + text_h + padding)
-                         cv.rectangle(annotated_frame, rect_start, rect_end, last_status_color, -1)
-                         text_pos = (padding + 5, padding + text_h + 5)
-                         cv.putText(annotated_frame, last_status_text, text_pos, cv.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness)
-
-                         out.write(annotated_frame)
-
-                     cap.release()
-                     out.release()
-                     progress_bar.progress(100, text="Video processing completed!")
-
-                 st.success("Video processed successfully!")
-
-                 if os.path.exists(temp_output_path):
-                     with open(temp_output_path, "rb") as file:
-                         video_bytes = file.read()
-
-                     st.download_button(
-                         label="📥 Download Processed Video",
-                         data=video_bytes,
-                         file_name=f"distraction_detected_{uploaded_video.name}",
-                         mime="video/mp4",
-                         key="download_distraction_video"
-                     )
-
-                     st.subheader("Sample Frame from Processed Video")
-                     cap_out = cv.VideoCapture(temp_output_path)
-                     ret, frame = cap_out.read()
-                     if ret:
-                         frame_rgb = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
-                         st.image(frame_rgb, caption="Sample frame with distraction detection", use_container_width=True)
-                     cap_out.release()
-
-                 try:
-                     os.unlink(temp_input_path)
-                     if os.path.exists(temp_output_path): os.unlink(temp_output_path)
-                 except Exception as e:
-                     st.warning(f"Failed to clean up temporary files: {e}")
-
  # --- Feature: Real-time Drowsiness Detection ---
  elif page == "Real-time Drowsiness Detection":
      st.title("🧠 Real-time Drowsiness Detection")
-     st.write("This will open your webcam and run the detection script.")
      if st.button("Start Drowsiness Detection"):
-         with st.spinner("Launching webcam..."):
              subprocess.Popen(["python3", "drowsiness_detection.py", "--mode", "webcam"])
-         st.success("Drowsiness detection started in a separate window. Press 'q' in that window to quit.")
  
  # --- Feature: Video Drowsiness Detection ---
  elif page == "Video Drowsiness Detection":
      st.title("📹 Video Drowsiness Detection")
-     st.write("Upload a video file to detect drowsiness and download the processed video.")
      uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
      if uploaded_video is not None:
          tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
          tfile.write(uploaded_video.read())
          temp_input_path = tfile.name
          temp_output_path = tempfile.mktemp(suffix="_processed.mp4")
          st.subheader("Original Video Preview")
          st.video(uploaded_video)
          if st.button("Process Video for Drowsiness Detection"):
              progress_bar = st.progress(0, text="Preparing to process video...")
-             with st.spinner("Processing video... This may take a while."):
-                 process = subprocess.Popen([
-                     "python3", "drowsiness_detection.py",
-                     "--mode", "video",
-                     "--input", temp_input_path,
-                     "--output", temp_output_path
-                 ], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
-                 stdout, stderr = process.communicate()
-                 if process.returncode == 0:
-                     progress_bar.progress(100, text="Video processing completed!")
-                     if os.path.exists(temp_output_path):
-                         st.success("Video processed successfully!")
-                         if stdout: st.code(stdout)
-                         with open(temp_output_path, "rb") as file: video_bytes = file.read()
-                         st.download_button(
-                             label="📥 Download Processed Video",
-                             data=video_bytes,
-                             file_name=f"drowsiness_detected_{uploaded_video.name}",
-                             mime="video/mp4",
-                             key="download_processed_video"
-                         )
-                         st.subheader("Sample Frame from Processed Video")
-                         cap = cv.VideoCapture(temp_output_path)
-                         ret, frame = cap.read()
-                         if ret: st.image(cv.cvtColor(frame, cv.COLOR_BGR2RGB), caption="Sample frame with drowsiness detection", use_container_width=True)
-                         cap.release()
-                     else:
-                         st.error("Error: Processed video file not found.")
-                         if stderr: st.code(stderr)
-                 else:
-                     st.error("An error occurred during video processing.")
-                     if stderr: st.code(stderr)
              try:
-                 if os.path.exists(temp_input_path): os.unlink(temp_input_path)
-                 if os.path.exists(temp_output_path): os.unlink(temp_output_path)
              except Exception as e:
-                 st.warning(f"Failed to clean up temporary files: {e}")

streamlit_app.py (updated version, changed sections)

  import asyncio
  import sys
  
+ # --- Boilerplate for compatibility ---
  if sys.platform.startswith('linux') and sys.version_info >= (3, 8):
      try:
          asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())
      except Exception:
          pass
+ 
  import streamlit as st
  from PIL import Image
  import numpy as np
 
  import os
  from ultralytics import YOLO
  import cv2 as cv
+ 
+ # --- NEW: Import your refactored video processing logic ---
+ from video_processor import process_video_with_progress
  
  model_path="best.pt"
  
 
  st.sidebar.title("🚗 Driver Distraction System")
  st.sidebar.write("Choose an option below:")
  
+ # --- Sidebar navigation ---
  page = st.sidebar.radio("Select Feature", [
+     "Distraction System",
+     "Video Drowsiness Detection",
+     "Real-time Drowsiness Detection"
  ])
  
  # --- Class Labels (for YOLO model) ---
+ st.sidebar.subheader("Class Names")
  class_names = ['drinking', 'hair and makeup', 'operating the radio', 'reaching behind',
                 'safe driving', 'talking on the phone', 'talking to passenger', 'texting']
  for idx, class_name in enumerate(class_names):
      st.sidebar.write(f"{idx}: {class_name}")
  
 
          else:
              st.warning("No distractions detected.")
  
  # --- Feature: Real-time Drowsiness Detection ---
  elif page == "Real-time Drowsiness Detection":
      st.title("🧠 Real-time Drowsiness Detection")
+     st.info("This feature requires a local webcam and will open a new window.")
+     st.warning("This feature is intended for local use and will not function in the cloud deployment.")
      if st.button("Start Drowsiness Detection"):
+         try:
+             # This call is fine, as your new drowsiness_detection.py is set up to handle it.
              subprocess.Popen(["python3", "drowsiness_detection.py", "--mode", "webcam"])
+             st.success("Attempted to launch detection window. Please check your desktop.")
+         except Exception as e:
+             st.error(f"Failed to start process: {e}")
  
  # --- Feature: Video Drowsiness Detection ---
  elif page == "Video Drowsiness Detection":
      st.title("📹 Video Drowsiness Detection")
+     st.write("Upload a video file to detect drowsiness and generate a report.")
      uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
+ 
      if uploaded_video is not None:
+         # Create a temporary file to hold the uploaded video
          tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
          tfile.write(uploaded_video.read())
          temp_input_path = tfile.name
          temp_output_path = tempfile.mktemp(suffix="_processed.mp4")
+ 
          st.subheader("Original Video Preview")
          st.video(uploaded_video)
+ 
          if st.button("Process Video for Drowsiness Detection"):
              progress_bar = st.progress(0, text="Preparing to process video...")
+ 
+             # --- NEW: Define a callback function for the progress bar ---
+             def streamlit_progress_callback(current, total):
+                 if total > 0:
+                     percent_complete = int((current / total) * 100)
+                     progress_bar.progress(percent_complete, text=f"Analyzing frame {current}/{total}...")
+ 
              try:
+                 with st.spinner("Processing video... This may take a while."):
+                     # --- NEW: Directly call your robust video processing function ---
+                     # No more complex subprocess logic needed!
+                     stats = process_video_with_progress(
+                         input_path=temp_input_path,
+                         output_path=temp_output_path,
+                         progress_callback=streamlit_progress_callback
+                     )
+ 
+                 progress_bar.progress(100, text="Video processing completed!")
+                 st.success("Video processed successfully!")
+ 
+                 # --- NEW: Display the returned statistics ---
+                 st.subheader("Detection Results")
+                 col1, col2, col3 = st.columns(3)
+                 col1.metric("Drowsy Events", stats.get('drowsy_events', 0))
+                 col2.metric("Yawn Events", stats.get('yawn_events', 0))
+                 col3.metric("Head Down Events", stats.get('head_down_events', 0))
+ 
+                 # Offer the processed video for download
+                 if os.path.exists(temp_output_path):
+                     with open(temp_output_path, "rb") as file:
+                         video_bytes = file.read()
+                     st.download_button(
+                         label="📥 Download Processed Video",
+                         data=video_bytes,
+                         file_name=f"drowsiness_detected_{uploaded_video.name}",
+                         mime="video/mp4"
+                     )
              except Exception as e:
+                 st.error(f"An error occurred during video processing: {e}")
+             finally:
+                 # Cleanup temporary files
+                 try:
+                     if os.path.exists(temp_input_path): os.unlink(temp_input_path)
+                     if os.path.exists(temp_output_path): os.unlink(temp_output_path)
+                 except Exception as e_clean:
+                     st.warning(f"Failed to clean up temporary files: {e_clean}")