import streamlit as st
from PIL import Image
import numpy as np
import subprocess
import time
import tempfile
import os
import uuid
from ultralytics import YOLO
import cv2 as cv

# Path to the trained YOLO weights; a relative path resolves inside the
# deployed app rather than on a specific developer's machine.
model_path = "models/best2.pt"
# --- Page Configuration ---
st.set_page_config(
    page_title="Driver Distraction System",
    page_icon="🚗",
    layout="wide",
    initial_sidebar_state="expanded",
)
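
# Optional (assumes Streamlit >= 1.18): cache the model so it is read from
# disk once per process instead of on every rerun; the YOLO(model_path)
# calls below could then be replaced with load_model(model_path).
@st.cache_resource
def load_model(path: str) -> YOLO:
    return YOLO(path)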
# --- Sidebar ---
st.sidebar.title("🚗 Driver Distraction System")
st.sidebar.write("Choose an option below:")

# Sidebar navigation
page = st.sidebar.radio("Select Feature", [
    "Distraction System",
    "Real-time Drowsiness Detection",
    "Video Drowsiness Detection"
])
# --- Class Labels (for YOLO model) ---
class_names = ['drinking', 'hair and makeup', 'operating the radio', 'reaching behind',
               'safe driving', 'talking on the phone', 'talking to passenger', 'texting']

# Sidebar Class Name Display
st.sidebar.subheader("Class Names")
for idx, class_name in enumerate(class_names):
    st.sidebar.write(f"{idx}: {class_name}")
# --- Feature: YOLO Distraction Detection ---
if page == "Distraction System":
    st.title("Driver Distraction System")
    st.write("Upload an image or video to detect distractions using the YOLO model.")

    # File type selection
    file_type = st.radio("Select file type:", ["Image", "Video"])

    if file_type == "Image":
        uploaded_file = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file).convert('RGB')
            image_np = np.array(image)
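            # Note: Ultralytics treats raw NumPy arrays as BGR (the OpenCV
            # convention) while PIL yields RGB; if predictions look off,
            # converting with cv.cvtColor(image_np, cv.COLOR_RGB2BGR) before
            # inference may help. Left unchanged here to preserve behavior.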
            col1, col2 = st.columns([1, 1])
            with col1:
                st.subheader("Uploaded Image")
                st.image(image, caption="Original Image", use_container_width=True)
            with col2:
                st.subheader("Detection Results")
                model = YOLO(model_path)
                start_time = time.time()
                results = model(image_np)
                end_time = time.time()
                prediction_time = end_time - start_time
                result = results[0]
                if len(result.boxes) > 0:
                    boxes = result.boxes
                    confidences = boxes.conf.cpu().numpy()
                    classes = boxes.cls.cpu().numpy()
                    class_names_dict = result.names
                    # Report the single highest-confidence detection
                    max_conf_idx = confidences.argmax()
                    predicted_class = class_names_dict[int(classes[max_conf_idx])]
                    confidence_score = confidences[max_conf_idx]
                    st.markdown(f"### Predicted Class: **{predicted_class}**")
                    st.markdown(f"### Confidence Score: **{confidence_score:.4f}** ({confidence_score*100:.1f}%)")
                    st.markdown(f"Inference Time: {prediction_time:.2f} seconds")
                else:
                    st.warning("No distractions detected.")
    else:  # Video processing
        uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
        if uploaded_video is not None:
            tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
            tfile.write(uploaded_video.read())
            temp_input_path = tfile.name
            # tempfile.mktemp() is deprecated; build a unique output path instead
            temp_output_path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}_distraction_detected.mp4")

            st.subheader("Video Information")
            cap = cv.VideoCapture(temp_input_path)
            fps = cap.get(cv.CAP_PROP_FPS)
            width = int(cap.get(cv.CAP_PROP_FRAME_WIDTH))
            height = int(cap.get(cv.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
            duration = total_frames / fps if fps > 0 else 0
            cap.release()
            col1, col2 = st.columns(2)
            with col1:
                st.metric("Duration", f"{duration:.2f} seconds")
                st.metric("Original FPS", f"{fps:.2f}")
            with col2:
                st.metric("Resolution", f"{width}x{height}")
                st.metric("Total Frames", total_frames)

            st.subheader("Original Video Preview")
            st.video(uploaded_video)
if st.button("Process Video for Distraction Detection"): | |
TARGET_PROCESSING_FPS = 10 | |
# --- NEW: Hyperparameter for the temporal smoothing logic --- | |
PERSISTENCE_CONFIDENCE_THRESHOLD = 0.40 # Stick with old class if found with >= 40% confidence | |
st.info(f"π For faster results, video will be processed at ~{TARGET_PROCESSING_FPS} FPS.") | |
st.info(f"π§ Applying temporal smoothing to reduce status flickering (Persistence Threshold: {PERSISTENCE_CONFIDENCE_THRESHOLD*100:.0f}%).") | |
progress_bar = st.progress(0, text="Starting video processing...") | |
with st.spinner(f"Processing video... This may take a while."): | |
model = YOLO(model_path) | |
cap = cv.VideoCapture(temp_input_path) | |
fourcc = cv.VideoWriter_fourcc(*'mp4v') | |
out = cv.VideoWriter(temp_output_path, fourcc, fps, (width, height)) | |
frame_skip_interval = max(1, round(fps / TARGET_PROCESSING_FPS)) | |
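                    # Example: a 30 FPS source with a 10 FPS target yields an
                    # interval of 3, so every third frame runs through YOLO and
                    # the frames in between reuse the last annotations.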
                    frame_count = 0
                    last_best_box_coords = None
                    last_best_box_label = ""
                    last_status_text = "Status: Initializing..."
                    last_status_color = (128, 128, 128)
                    # Last confirmed class, carried across frames for smoothing
                    last_confirmed_class_name = 'safe driving'
                    while cap.isOpened():
                        ret, frame = cap.read()
                        if not ret:
                            break
                        frame_count += 1
                        # Clamp to 100: CAP_PROP_FRAME_COUNT can under-report
                        progress = min(100, int((frame_count / total_frames) * 100)) if total_frames > 0 else 0
                        progress_bar.progress(progress, text=f"Analyzing frame {frame_count}/{total_frames}")
                        annotated_frame = frame.copy()
                        if frame_count % frame_skip_interval == 0:
                            results = model(annotated_frame)
                            result = results[0]
                            last_best_box_coords = None  # Reset box for this processing cycle
                            if len(result.boxes) > 0:
                                boxes = result.boxes
                                class_names_dict = result.names
                                confidences = boxes.conf.cpu().numpy()
                                classes = boxes.cls.cpu().numpy()
                                # --- Stability logic (temporal smoothing) ---
                                final_box_to_use = None
                                # 1. Prefer the last confirmed class if it reappears with reasonable confidence
                                for i in range(len(boxes)):
                                    current_class_name = class_names_dict[int(classes[i])]
                                    if current_class_name == last_confirmed_class_name and confidences[i] >= PERSISTENCE_CONFIDENCE_THRESHOLD:
                                        final_box_to_use = boxes[i]
                                        break
                                # 2. Otherwise fall back to the highest-confidence detection in this frame
                                if final_box_to_use is None:
                                    max_conf_idx = confidences.argmax()
                                    final_box_to_use = boxes[max_conf_idx]
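                                # This gives the label hysteresis: the incumbent
                                # class keeps the status as long as it is
                                # re-detected at or above the threshold, even if
                                # another class momentarily scores higher; the
                                # label switches only once the incumbent drops
                                # below 40% confidence or disappears.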
                                # Process the chosen detection
                                x1, y1, x2, y2 = final_box_to_use.xyxy[0].cpu().numpy()
                                confidence = final_box_to_use.conf[0].cpu().numpy()
                                class_id = int(final_box_to_use.cls[0].cpu().numpy())
                                class_name = class_names_dict[class_id]
                                # Update the smoothing state for subsequent frames
                                last_confirmed_class_name = class_name
                                last_best_box_coords = (int(x1), int(y1), int(x2), int(y2))
                                last_best_box_label = f"{class_name}: {confidence:.2f}"
                                if class_name != 'safe driving':
                                    last_status_text = f"Status: {class_name.replace('_', ' ').title()}"
                                    last_status_color = (0, 0, 255)  # Red (BGR)
                                else:
                                    last_status_text = "Status: Safe Driving"
                                    last_status_color = (0, 128, 0)  # Green (BGR)
                            else:
                                # No detections: reset to safe driving
                                last_confirmed_class_name = 'safe driving'
                                last_status_text = "Status: Safe Driving"
                                last_status_color = (0, 128, 0)
                        # Draw annotations on EVERY frame using the last known data
                        if last_best_box_coords:
                            cv.rectangle(annotated_frame, (last_best_box_coords[0], last_best_box_coords[1]),
                                         (last_best_box_coords[2], last_best_box_coords[3]), (0, 255, 0), 2)
                            cv.putText(annotated_frame, last_best_box_label,
                                       (last_best_box_coords[0], last_best_box_coords[1] - 10),
                                       cv.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
                        # Draw the status banner on a filled background
                        font_scale, font_thickness = 1.0, 2
                        (text_w, text_h), _ = cv.getTextSize(last_status_text, cv.FONT_HERSHEY_SIMPLEX, font_scale, font_thickness)
                        padding = 10
                        rect_start = (padding, padding)
                        rect_end = (padding + text_w + padding, padding + text_h + padding)
                        cv.rectangle(annotated_frame, rect_start, rect_end, last_status_color, -1)
                        text_pos = (padding + 5, padding + text_h + 5)
                        cv.putText(annotated_frame, last_status_text, text_pos, cv.FONT_HERSHEY_SIMPLEX, font_scale, (255, 255, 255), font_thickness)
                        out.write(annotated_frame)

                    cap.release()
                    out.release()
                progress_bar.progress(100, text="Video processing completed!")
                st.success("Video processed successfully!")
                if os.path.exists(temp_output_path):
                    with open(temp_output_path, "rb") as file:
                        video_bytes = file.read()
                    st.download_button(
                        label="📥 Download Processed Video",
                        data=video_bytes,
                        file_name=f"distraction_detected_{uploaded_video.name}",
                        mime="video/mp4",
                        key="download_distraction_video"
                    )
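                    # The download button holds the bytes in memory, so the
                    # temporary output file can be deleted later without
                    # breaking the download.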
st.subheader("Sample Frame from Processed Video") | |
cap_out = cv.VideoCapture(temp_output_path) | |
ret, frame = cap_out.read() | |
if ret: | |
frame_rgb = cv.cvtColor(frame, cv.COLOR_BGR2RGB) | |
st.image(frame_rgb, caption="Sample frame with distraction detection", use_container_width=True) | |
cap_out.release() | |
                try:
                    os.unlink(temp_input_path)
                    if os.path.exists(temp_output_path):
                        os.unlink(temp_output_path)
                except Exception as e:
                    st.warning(f"Failed to clean up temporary files: {e}")
# --- Feature: Real-time Drowsiness Detection ---
elif page == "Real-time Drowsiness Detection":
    st.title("😴 Real-time Drowsiness Detection")
    st.write("This will open your webcam and run the detection script.")
    if st.button("Start Drowsiness Detection"):
        with st.spinner("Launching webcam..."):
            subprocess.Popen(["python3", "drowsiness_detection.py", "--mode", "webcam"])
        st.success("Drowsiness detection started in a separate window. Press 'q' in that window to quit.")
# --- Feature: Video Drowsiness Detection ---
elif page == "Video Drowsiness Detection":
    st.title("📹 Video Drowsiness Detection")
    st.write("Upload a video file to detect drowsiness and download the processed video.")
    uploaded_video = st.file_uploader("Upload Video", type=["mp4", "avi", "mov", "mkv", "webm"])
    if uploaded_video is not None:
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tfile.write(uploaded_video.read())
        temp_input_path = tfile.name
        # tempfile.mktemp() is deprecated; build a unique output path instead
        temp_output_path = os.path.join(tempfile.gettempdir(), f"{uuid.uuid4().hex}_processed.mp4")

        st.subheader("Original Video Preview")
        st.video(uploaded_video)
if st.button("Process Video for Drowsiness Detection"): | |
progress_bar = st.progress(0, text="Preparing to process video...") | |
with st.spinner("Processing video... This may take a while."): | |
process = subprocess.Popen([ | |
"python3", "drowsiness_detection.py", | |
"--mode", "video", | |
"--input", temp_input_path, | |
"--output", temp_output_path | |
], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True) | |
stdout, stderr = process.communicate() | |
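                # communicate() blocks until the child script exits, capturing
                # its stdout/stderr, so the spinner stays visible for the run.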
            if process.returncode == 0:
                progress_bar.progress(100, text="Video processing completed!")
                if os.path.exists(temp_output_path):
                    st.success("Video processed successfully!")
                    if stdout:
                        st.code(stdout)
                    with open(temp_output_path, "rb") as file:
                        video_bytes = file.read()
                    st.download_button(
                        label="📥 Download Processed Video",
                        data=video_bytes,
                        file_name=f"drowsiness_detected_{uploaded_video.name}",
                        mime="video/mp4",
                        key="download_processed_video"
                    )
                    st.subheader("Sample Frame from Processed Video")
                    cap = cv.VideoCapture(temp_output_path)
                    ret, frame = cap.read()
                    if ret:
                        st.image(cv.cvtColor(frame, cv.COLOR_BGR2RGB), caption="Sample frame with drowsiness detection", use_container_width=True)
                    cap.release()
                else:
                    st.error("Error: Processed video file not found.")
                    if stderr:
                        st.code(stderr)
            else:
                st.error("An error occurred during video processing.")
                if stderr:
                    st.code(stderr)
            try:
                if os.path.exists(temp_input_path):
                    os.unlink(temp_input_path)
                if os.path.exists(temp_output_path):
                    os.unlink(temp_output_path)
            except Exception as e:
                st.warning(f"Failed to clean up temporary files: {e}")