# NOTE: removed Hugging Face Spaces page residue ("Spaces: / Runtime error")
# that was accidentally captured with this file.
# app_webrtc.py
# Drive Paddy: real-time drowsiness detection streamed over WebRTC via Gradio.

# Standard library
import io
import os
import time

# Third-party
import cv2
import gradio as gr
import numpy as np
import yaml
from dotenv import load_dotenv
from gradio_webrtc import WebRTC
from pydub import AudioSegment
from scipy.io.wavfile import read as read_wav

# Project-local (drive_paddy package structure)
from src.alerting.alert_system import get_alerter
from src.detection.factory import get_detector
# --- Load Configuration and Environment Variables ---
load_dotenv()
config_path = 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)

# Secrets are read from the environment (.env via load_dotenv above).
secrets = {
    "gemini_api_key": os.getenv("GEMINI_API_KEY"),
}

# --- Initialize Backend Components ---
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])

# Slider defaults come from config, with hard-coded fallbacks if absent.
geo_settings = config.get('geometric_settings', {})
drowsiness_levels = geo_settings.get('drowsiness_levels', {})
SLIGHTLY_DROWSY_DEFAULT = drowsiness_levels.get('slightly_drowsy_threshold', 0.3)
VERY_DROWSY_DEFAULT = drowsiness_levels.get('very_drowsy_threshold', 0.8)

# --- Global state for audio (simpler than queues for this component) ---
# Holds the most recent alert as a (sample_rate, ndarray) tuple; the UI polls
# it via get_audio_update(). This is a common pattern in simple Gradio apps.
latest_audio_alert = None
# --- Main Processing Function --- | |
def process_stream(frame: np.ndarray, sensitivity_threshold: float) -> np.ndarray:
    """Process one webcam frame and return the annotated frame.

    Runs the drowsiness detector on the frame (which draws landmarks and
    status overlays), and, when the driver is not "Awake", asks the alerter
    for an audio clip and stashes it in the module-level
    ``latest_audio_alert`` for the UI's polling loop to pick up.

    Args:
        frame: BGR image from the WebRTC stream, or None before the first frame.
        sensitivity_threshold: Slider value. NOTE(review): currently unused
            here — presumably intended to be forwarded to the detector; confirm.

    Returns:
        The processed frame (or a black 480x640 placeholder when frame is None).
    """
    global latest_audio_alert

    # Before the camera delivers a frame, show a black placeholder.
    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8)

    processed_frame, indicators, _ = detector.process_frame(frame)
    drowsiness_level = indicators.get("drowsiness_level", "Awake")

    if drowsiness_level != "Awake":
        audio_data = alerter.trigger_alert(level=drowsiness_level)
        if audio_data:
            # The alerter returns MP3 bytes; Gradio's Audio component wants
            # (sample_rate, ndarray), so transcode MP3 -> WAV -> numpy.
            try:
                mp3_buffer = io.BytesIO(audio_data)
                segment = AudioSegment.from_mp3(mp3_buffer)
                wav_buffer = io.BytesIO()
                segment.export(wav_buffer, format="wav")
                wav_buffer.seek(0)
                sample_rate, data = read_wav(wav_buffer)
                latest_audio_alert = (sample_rate, data)
            except Exception as e:
                # Best-effort: a failed transcode should never kill the stream.
                print(f"Audio processing error: {e}")
                latest_audio_alert = None
    else:
        # Driver is awake again: allow the alerter to fire on the next event.
        alerter.reset_alert()

    return processed_frame
# --- Function to check for and return audio alerts --- | |
def get_audio_update():
    """Return and clear the pending audio alert, if any.

    Polled periodically by the UI (``app.load(..., every=1)``). Returns the
    stored (sample_rate, ndarray) tuple exactly once per alert, or None when
    nothing is pending, so the same clip is never replayed.
    """
    global latest_audio_alert
    if latest_audio_alert:
        audio_to_play = latest_audio_alert
        latest_audio_alert = None  # Clear so the alert plays only once.
        return audio_to_play
    return None
# --- Gradio UI Definition --- | |
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue", secondary_hue="blue")) as app:
    gr.HTML(
        """
        <div align="center">
            <img src="https://em-content.zobj.net/source/samsung/380/automobile_1f697.png" alt="Car Emoji" width="100"/>
            <h1>Drive Paddy</h1>
        </div>
        """
    )
    with gr.Row():
        # The WebRTC component both captures the webcam and displays the
        # processed output frames returned by process_stream.
        webrtc_output = WebRTC(
            label="Live Detection Feed",
            video_source="webcam",
        )
    with gr.Row():
        sensitivity_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=SLIGHTLY_DROWSY_DEFAULT,
            step=0.05,
            label="Alert Sensitivity Threshold",
            info="Lower value = more sensitive to drowsiness signs."
        )

    # Hidden audio component that auto-plays queued alert clips.
    audio_player = gr.Audio(autoplay=True, visible=False)

    # Connect the WebRTC stream to the processing function. The component
    # streams frames as fast as possible; no 'every' parameter is needed.
    webrtc_output.stream(
        fn=process_stream,
        inputs=[webrtc_output, sensitivity_slider],
        outputs=[webrtc_output],
    )

    # Poll for audio updates on a separate low-frequency loop — more stable
    # than returning multiple values from the high-frequency video stream.
    app.load(
        fn=get_audio_update,
        inputs=None,
        outputs=[audio_player],
        every=1  # Check for a new audio alert every 1 second.
    )
# --- Launch the App --- | |
# --- Launch the App ---
if __name__ == "__main__":
    # debug=True surfaces tracebacks in the console during development.
    app.launch(debug=True)