Spaces:
Sleeping
Sleeping
File size: 6,540 Bytes
5101617 95b307f 5101617 6c72f8e 5101617 95b307f fcccf01 95b307f 5101617 817a521 5101617 817a521 fcccf01 ebefc6a fcccf01 95b307f 817a521 95b307f 5101617 95b307f fcccf01 95b307f 5101617 95b307f 5101617 95b307f fcccf01 5101617 95b307f 5101617 95b307f 5101617 95b307f 5101617 95b307f 5101617 95b307f 963795e 95b307f 5101617 95b307f 8977de8 95b307f 5101617 8977de8 95b307f fcccf01 95b307f 8977de8 95b307f fcccf01 95b307f 5101617 8977de8 fcccf01 8977de8 5101617 95b307f 5101617 8977de8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
# app_gradio.py
import gradio as gr
import numpy as np
import torch
import soundfile as sf
import os
import yaml
import spaces
from dotenv import load_dotenv
from threading import Thread
from gradio_webrtc import WebRTC
from twilio.rest import Client
# --- TTS & AI Imports ---
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
from streamer import ParlerTTSStreamer # Make sure streamer.py is available
from src.detection.factory import get_detector
from src.alerting.alert_system import get_alerter
# --- Load Configuration and Environment Variables ---
# This part is the same as our Streamlit app
load_dotenv()  # populate os.environ from a local .env file, if present
config_path = 'config.yaml'
# Read with an explicit encoding so parsing does not depend on the platform's
# default locale encoding.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# Secrets come from the environment (optionally populated via .env above),
# not from config.yaml.
secrets = {
    "gemini_api_key": os.getenv("GEMINI_API_KEY"),
}
# --- Initialize Backend Components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete. Launching UI...")

# Optional WebRTC relay: when both environment variables are set they are
# passed to the Twilio client as credentials, a token is minted, and its ICE
# servers are used with relay-only transport. Otherwise the WebRTC component
# receives rtc_configuration=None.
account_sid = os.getenv("TURN_USERNAME")
auth_token = os.getenv("TURN_CREDENTIAL")
rtc_configuration = None
if account_sid and auth_token:
    client = Client(account_sid, auth_token)
    token = client.tokens.create()
    rtc_configuration = {
        "iceServers": token.ice_servers,
        "iceTransportPolicy": "relay",
    }
# --- Parler-TTS Model Setup (Requires GPU) ---
print("Loading Parler-TTS model. This may take a moment...")
# Pick the compute device and a matching dtype: half precision on CUDA,
# full precision on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
    print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
# Using a smaller, faster model suitable for real-time alerts
repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(
    repo_id,
    torch_dtype=torch_dtype,
).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# NOTE(review): feature_extractor does not appear to be used in this module;
# confirm before removing it.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")
# --- Audio Streaming Generator Function ---
@spaces.GPU
def stream_alert_audio(text_prompt, play_steps_in_s=2.0):
    """
    Yield synthesized speech for ``text_prompt`` chunk by chunk.

    ``model.generate`` runs in a background thread and feeds a
    ``ParlerTTSStreamer``; this generator yields each audio chunk as soon as
    the streamer produces it.

    Args:
        text_prompt: The text to speak.
        play_steps_in_s: Approximate length, in seconds, of each yielded
            chunk. Defaults to 2.0 (the value previously hard-coded).

    Yields:
        ``(sampling_rate, audio_chunk)`` tuples.

    Side effects:
        When the generator is exhausted or closed, ``alerter.reset_alert()``
        is called so that a new alert can be triggered later. This only
        happens once iteration has started; an un-iterated generator never
        runs its body.
    """
    sampling_rate = model.config.sampling_rate
    # play_steps is measured in decoder steps (audio-codec frames), not in
    # audio samples, so it must be derived from the codec frame rate.
    # Multiplying by the sampling rate gives a step count far beyond a normal
    # generation length, so the streamer would only emit once, at the very end.
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = int(frame_rate * play_steps_in_s)

    description = "Jenny is A female speaker with a clear and urgent voice."  # Voice prompt for TTS
    prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)

    # Setup the streamer
    streamer = ParlerTTSStreamer(model, device, play_steps=play_steps)
    generation_kwargs = dict(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,  # Increase for more vocal variety
        repetition_penalty=1.2,
    )
    # Run generation in a separate thread so chunks can be yielded while it
    # is still producing audio.
    # NOTE(review): the thread is never joined; if the consumer closes this
    # generator early, generation keeps running in the background until done.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    try:
        thread.start()
        print(f"Audio stream started for: '{text_prompt}'")
        # Yield audio chunks as they become available
        for new_audio_chunk in streamer:
            yield (sampling_rate, new_audio_chunk)
    finally:
        # Runs when the stream finishes or the generator is closed: reset the
        # alerter so that a new alert can be triggered later.
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
# --- Main Webcam Processing Function ---
@spaces.GPU
def process_live_frame(frame):
    """
    Run drowsiness detection on one webcam frame and start an alert if needed.

    Args:
        frame: The incoming webcam frame, or None when no frame is available.

    Returns:
        A 3-tuple ``(processed_frame, status_text, audio_output)``:
        - processed_frame: the frame returned by ``detector.process_frame``,
          or a blank 480x640x3 uint8 image when ``frame`` is None;
        - status_text: lighting / drowsiness status / score summary;
        - audio_output: a ``stream_alert_audio`` generator when the alerter
          issued a new alert, otherwise None.
    """
    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8), "Status: Inactive", None

    processed_frame, indicators, _ = detector.process_frame(frame)
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    score = indicators.get("details", {}).get("Score", 0)

    # Build status text
    status_text = f"Lighting: {lighting}\n"
    if lighting == "Low":
        status_text += "Detection paused due to low light."
    else:
        status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"

    # --- Alert Trigger Logic ---
    audio_output = None
    if drowsiness_level != "Awake":
        # alerter.trigger_alert() returns the alert TEXT if not on cooldown, otherwise None.
        alert_text = alerter.trigger_alert(level=drowsiness_level)
        # Deliberately no fallback when alert_text is None. Starting a stream
        # anyway would bypass the alerter's cooldown and launch a new TTS
        # stream on every drowsy frame.
        if alert_text:
            audio_output = stream_alert_audio(alert_text)
    return processed_frame, status_text, audio_output
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# π Drive Paddy - Drowsiness Detection (Streaming)")
    gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")
    with gr.Row():
        with gr.Column(scale=2):
            webcam_input = WebRTC(label="Stream", rtc_configuration=rtc_configuration)
        with gr.Column(scale=1):
            processed_output = gr.Image(label="Processed Feed")
            status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
            # Streaming audio output for alerts; autoplay=True so alert audio
            # plays as soon as it is received.
            audio_alert_output = gr.Audio(
                label="Alert System",
                autoplay=True,
                visible=True,  # the player (and its controls) is shown
                streaming=True
            )
    # process_live_frame returns (processed_frame, status_text, audio_output),
    # so the outputs list names one component per returned value, in order.
    # NOTE(review): gradio_webrtc's send-receive mode may require the WebRTC
    # component itself as the only output, with extra values routed through
    # AdditionalOutputs. Confirm against the installed gradio_webrtc version.
    webcam_input.stream(
        fn=process_live_frame,
        inputs=[webcam_input],
        outputs=[processed_output, status_output, audio_alert_output],
        time_limit=10
    )
# --- Launch the App ---
if __name__ == "__main__":
    # Start the Gradio server when this file is run as a script.
    app.launch(debug=True)