Spaces:
Sleeping
Sleeping
File size: 6,388 Bytes
5101617 95b307f 5101617 95b307f fcccf01 95b307f b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 817a521 5101617 b6b2705 817a521 b6b2705 fcccf01 ebefc6a b6b2705 fcccf01 b6b2705 95b307f b6b2705 95b307f 5101617 b6b2705 95b307f b6b2705 95b307f b6b2705 95b307f b6b2705 95b307f 5101617 95b307f b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 b6b2705 5101617 95b307f 963795e b6b2705 95b307f 8977de8 95b307f b6b2705 5101617 8977de8 95b307f b6b2705 95b307f 8977de8 b6b2705 95b307f b6b2705 8977de8 5101617 b6b2705 95b307f 8977de8 b6b2705 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 |
# app_gradio.py
import gradio as gr
import numpy as np
import torch
import os
import yaml
from dotenv import load_dotenv
from threading import Thread
from gradio_webrtc import WebRTC
from twilio.rest import Client
# --- TTS & AI Imports ---
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
from streamer import ParlerTTSStreamer
# --- Local Project Imports ---
from src.detection.factory import get_detector
from src.alerting.alert_system import get_alerter
# --- Load Configuration and Environment Variables ---
# Load variables from a .env file (if present) before any os.getenv() lookups below.
load_dotenv()
config_path = 'config.yaml'
# Decode the config explicitly as UTF-8 instead of relying on the platform's
# locale default, so non-ASCII values parse the same way on every host.
with open(config_path, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)
# The API key may be None when GEMINI_API_KEY is not set; it is passed through as-is.
secrets = {"gemini_api_key": os.getenv("GEMINI_API_KEY")}

# --- Initialize Backend Components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete.")
# --- Twilio TURN Server Setup ---
# The Twilio account SID / auth token are read from the TURN_USERNAME and
# TURN_CREDENTIAL environment variables; both must be present to attempt Twilio.
account_sid = os.getenv("TURN_USERNAME")
auth_token = os.getenv("TURN_CREDENTIAL")
rtc_configuration = None
if account_sid and auth_token:
    try:
        client = Client(account_sid, auth_token)
        token = client.tokens.create()
        rtc_configuration = {"iceServers": token.ice_servers}
        print("Twilio TURN server configured successfully.")
    except Exception as e:
        # Best-effort: any Twilio failure is reported and the STUN fallback below is used.
        print(f"Warning: Failed to create Twilio token. Using public STUN server. Error: {e}")

# Fallback to a public STUN server if Twilio fails or is not configured
if rtc_configuration is None:
    print("Using public STun server.")
    rtc_configuration = {
        "iceServers": [
            {"urls": ["stun:stun.l.google.com:19302"]},
        ],
    }
# --- Parler-TTS Model Setup ---
print("Loading Parler-TTS model...")
# Pick the device and matching precision together: half precision on GPU,
# full precision on CPU.
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    torch_dtype = torch.float32
    print("\nWARNING: Running Parler-TTS on a CPU is slow. A GPU is highly recommended.\n")

repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")
# --- Audio Streaming Generator ---
def stream_alert_audio(text_prompt, play_steps_in_s=2.0):
    """Stream synthesized speech for ``text_prompt`` as audio chunks.

    Generation runs in a background thread (``model.generate``) while this
    generator relays whatever the streamer produces.

    Args:
        text_prompt: The text to be spoken.
        play_steps_in_s: Approximate duration, in seconds, of audio to
            accumulate before each chunk is emitted. Defaults to 2.0.

    Yields:
        ``(sampling_rate, audio_chunk)`` tuples, one per chunk produced by
        the streamer.

    Side effects:
        Calls ``alerter.reset_alert()`` when iteration ends, including when
        the generator is closed early or an error occurs.
    """
    sampling_rate = model.config.sampling_rate
    # A generation streamer is fed once per decoder step, so ``play_steps`` has
    # to be expressed in codec frames, not audio samples. Multiplying by the
    # sampling rate would set a threshold generation never reaches, which delays
    # all audio until the very end.
    # NOTE(review): assumes the local ``streamer.ParlerTTSStreamer`` mirrors the
    # upstream Parler-TTS streamer and that the audio encoder config exposes
    # ``frame_rate`` - confirm.
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = int(frame_rate * play_steps_in_s)

    description = "A female speaker with a clear and urgent voice."
    prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)

    streamer = ParlerTTSStreamer(model, device, play_steps=play_steps)
    generation_kwargs = dict(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        repetition_penalty=1.2,
    )
    # Generate in the background so chunks can be yielded as soon as they are ready.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    try:
        thread.start()
        print(f"Audio stream started for: '{text_prompt}'")
        for new_audio_chunk in streamer:
            yield (sampling_rate, new_audio_chunk)
    finally:
        # Always release the alerter, even if the consumer stops listening early.
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
# --- Decoupled Processing Functions ---
def process_video_and_update_state(frame_dict: dict, state: dict):
    """
    HIGH-FREQUENCY LOOP: run the detector on one frame and record its indicators.

    This function's speed directly impacts video latency.

    Args:
        frame_dict: Either a mapping carrying the frame under the ``"video"``
            key, or a bare NumPy frame. Anything else (``None``, an empty
            mapping, a missing/``None`` ``"video"`` entry, an empty array) is
            treated as "no frame available".
        state: Shared state dictionary; ``None`` is treated as an empty dict.

    Returns:
        ``(frame, state)``: the detector's processed frame, or a black
        480x640 RGB placeholder when no frame is available, together with the
        state (updated with ``state['indicators']`` when a frame was processed).
    """
    if state is None:
        # Downstream code calls ``state.get(...)``; never hand back a None state.
        state = {}

    if isinstance(frame_dict, dict):
        frame = frame_dict.get("video")
    elif isinstance(frame_dict, np.ndarray) and frame_dict.size:
        # A bare frame: testing its truthiness would raise for a multi-element
        # array, so it is detected by type and forwarded as-is.
        frame = frame_dict
    else:
        frame = None

    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8), state

    processed_frame, indicators, _ = detector.process_frame(frame)
    state['indicators'] = indicators
    return processed_frame, state
def update_ui_from_state(state: dict):
    """
    LOW-FREQUENCY LOOP: turn the shared state into status text and an optional alert.

    This runs independently of the video loop.

    Args:
        state: Shared state dictionary; the latest detector output is read
            from ``state['indicators']``. A missing or ``None`` state,
            indicators, or details mapping falls back to the defaults
            ("Awake", "Good" lighting, score 0).

    Returns:
        ``(status_text, audio_output)``. ``audio_output`` is ``None`` unless
        the drowsiness level is not "Awake" and the alerter returned alert
        text, in which case it is the generator from ``stream_alert_audio``.
    """
    indicators = (state or {}).get('indicators') or {}
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    details = indicators.get("details") or {}

    # The score is formatted with ``:.2f``; coerce it so a missing, None, or
    # non-numeric value cannot raise inside the UI loop.
    try:
        score = float(details.get("Score", 0))
    except (TypeError, ValueError):
        score = 0.0

    if lighting == "Low":
        status_text = "Lighting: Low\nDetection paused due to low light."
    else:
        status_text = f"Lighting: {lighting}\nStatus: {drowsiness_level}\nScore: {score:.2f}"

    audio_output = None
    if drowsiness_level != "Awake":
        alert_text = alerter.trigger_alert(level=drowsiness_level)
        if alert_text:
            # NOTE(review): this hands the UI an un-iterated generator object;
            # verify that the audio component consumes it. If not, this handler
            # likely needs to become a generator itself - confirm.
            audio_output = stream_alert_audio(alert_text)
    return status_text, audio_output
# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    # NOTE(review): the leading "π" in the title looks like a mis-encoded emoji;
    # the string is left unchanged because the intended character is unknown.
    gr.Markdown("# π Drive Paddy - Drowsiness Detection (WebRTC)")
    gr.Markdown("Low-latency video processing via WebRTC, with decoupled UI updates for smooth performance.")

    # Shared state object to pass data between the two processing loops
    shared_state = gr.State(value={'indicators': {}})

    with gr.Row():
        with gr.Column(scale=2):
            # The same WebRTC component is both the input and the output of the video stream.
            webcam = WebRTC(label="Live Camera Feed", rtc_configuration=rtc_configuration)
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
            audio_alert_output = gr.Audio(label="Alert System", autoplay=True, visible=False, streaming=True)

    # LOOP 1: High-Frequency Video Stream (as fast as possible)
    # Takes video from the webcam, processes it, and sends it right back.
    webcam.stream(
        fn=process_video_and_update_state,
        inputs=[webcam, shared_state],
        outputs=[webcam, shared_state],
    )

    # LOOP 2: Low-Frequency UI Updates (4 times per second)
    # A load event fires only once per page load, so a timer component drives
    # the periodic refresh: every 0.25 s it reads the shared state and updates
    # the status text and alert audio.
    ui_refresh_timer = gr.Timer(0.25)
    ui_refresh_timer.tick(
        fn=update_ui_from_state,
        inputs=[shared_state],
        outputs=[status_output, audio_alert_output],
    )
if __name__ == "__main__":
    # Script entry point: announce start-up, then serve the UI with debug output enabled.
    print("Starting Drive Paddy WebRTC Application...")
    launch_options = {"debug": True}
    app.launch(**launch_options)
|