File size: 5,982 Bytes
5101617
 
 
95b307f
 
5101617
 
 
95b307f
 
 
 
 
 
5101617
 
 
 
 
 
 
 
 
 
 
 
 
 
 
817a521
5101617
 
817a521
 
95b307f
 
 
 
 
 
 
817a521
95b307f
 
 
 
 
 
5101617
95b307f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5101617
95b307f
 
 
 
 
 
 
 
 
 
5101617
95b307f
5101617
 
95b307f
 
5101617
 
95b307f
5101617
 
 
 
 
 
95b307f
5101617
 
 
 
 
 
95b307f
5101617
 
95b307f
 
 
 
 
 
 
 
 
5101617
 
 
95b307f
8977de8
95b307f
 
 
5101617
8977de8
95b307f
8977de8
95b307f
8977de8
 
95b307f
 
 
 
 
 
 
 
5101617
8977de8
 
 
 
 
5101617
95b307f
5101617
8977de8
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
# app_gradio.py
import gradio as gr
import numpy as np
import torch
import soundfile as sf
import os
import yaml
from dotenv import load_dotenv
from threading import Thread

# --- TTS & AI Imports ---
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
from streamer import ParlerTTSStreamer # Make sure streamer.py is available

from src.detection.factory import get_detector
from src.alerting.alert_system import get_alerter

# --- Load Configuration and Environment Variables ---
# load_dotenv() runs first so that the os.getenv() lookup below can see any
# values supplied through a local .env file.
load_dotenv()

# The YAML file is parsed with safe_load and passed unchanged to both
# factory functions below.
config_path = 'config.yaml'
with open(config_path) as f:
    config = yaml.safe_load(f)

# os.getenv returns None when GEMINI_API_KEY is not set; that None is
# forwarded as-is to get_alerter().
secrets = dict(gemini_api_key=os.getenv("GEMINI_API_KEY"))

# --- Initialize Backend Components ---
# Both objects are module-level singletons shared by every request handler
# defined further down in this file.
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete. Launching UI...")

# --- Parler-TTS Model Setup (GPU strongly recommended) ---
print("Loading Parler-TTS model. This may take a moment...")

# Pick device and precision together: half precision on CUDA, full precision
# on CPU (where a warning is printed because inference will be slow).
if torch.cuda.is_available():
    device = "cuda:0"
    torch_dtype = torch.float16
else:
    device = "cpu"
    print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
    torch_dtype = torch.float32

# Using a smaller, faster model suitable for real-time alerts
repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
# NOTE(review): feature_extractor is loaded here but is not referenced
# anywhere else in the visible part of this file.
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")

# --- Audio Streaming Generator Function ---
def stream_alert_audio(text_prompt, play_steps_in_s=2.0):
    """
    Generate speech for ``text_prompt`` and yield it chunk by chunk.

    Generation runs on a background thread via ``model.generate`` while this
    generator drains the streamer, so audio can be played before the whole
    utterance has been synthesized.

    Args:
        text_prompt: The alert text to speak.
        play_steps_in_s: Approximate length, in seconds, of each streamed
            audio chunk. Defaults to 2.0.

    Yields:
        ``(sampling_rate, audio_chunk)`` tuples, where ``audio_chunk`` is
        whatever the streamer produces (presumably a NumPy array — verify
        against streamer.py).

    Note:
        The ``finally`` block resets the alerter once iteration ends, whether
        the stream was exhausted, closed early, or failed. It only runs if the
        generator has actually been started by its consumer.
    """
    sampling_rate = model.config.sampling_rate
    description = "Jenny is A female speaker with a clear and urgent voice." # Voice prompt for TTS

    # The streamer's play_steps is counted in decoder steps (codec frames),
    # not in audio samples, so seconds must be scaled by the audio encoder's
    # frame rate. Scaling by the sampling rate instead gives a threshold far
    # above what a short alert generates, so no audio would be emitted until
    # generation ends.
    # NOTE(review): assumes the local streamer.py follows the upstream
    # Parler-TTS ParlerTTSStreamer semantics — confirm.
    frame_rate = model.audio_encoder.config.frame_rate
    play_steps = max(1, int(frame_rate * play_steps_in_s))

    prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)

    # Setup the streamer
    streamer = ParlerTTSStreamer(model, device, play_steps=play_steps)

    # For Parler-TTS the voice description goes in input_ids and the text to
    # speak goes in prompt_input_ids.
    generation_kwargs = dict(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0, # Increase for more vocal variety
        repetition_penalty=1.2,
    )

    # Run generation in a separate thread so this generator can hand out
    # chunks while the model is still producing audio.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)

    try:
        thread.start()
        print(f"Audio stream started for: '{text_prompt}'")
        # Yield audio chunks as they become available
        for new_audio_chunk in streamer:
            yield (sampling_rate, new_audio_chunk)
    finally:
        # CRITICAL: reset the alerter state so that a new alert can be
        # triggered later, regardless of how the stream ended.
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()
        
# --- Main Webcam Processing Function ---
def process_live_frame(frame):
    """
    Run drowsiness detection on one webcam frame.

    Args:
        frame: The current webcam image, or None when no frame is available.

    Returns:
        A 3-tuple ``(image, status_text, audio)``:
        - image: the detector's processed frame, or a black 480x640 RGB
          placeholder when ``frame`` is None;
        - status_text: a multi-line, human-readable status summary;
        - audio: the generator object created by ``stream_alert_audio`` when
          a new alert starts, otherwise None.
    """
    # Nothing to analyse yet: hand back a blank image and an idle status.
    if frame is None:
        placeholder = np.zeros((480, 640, 3), dtype=np.uint8)
        return placeholder, "Status: Inactive", None

    processed_frame, indicators, _ = detector.process_frame(frame)

    # Missing indicator keys fall back to the "everything is fine" defaults.
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    score = indicators.get("details", {}).get("Score", 0)

    # Build status text: the lighting line is always shown; level and score
    # are replaced by a notice when lighting is reported as "Low".
    if lighting == "Low":
        detail = "Detection paused due to low light."
    else:
        detail = f"Status: {drowsiness_level}\nScore: {score:.2f}"
    status_text = f"Lighting: {lighting}\n" + detail

    # --- Alert Trigger Logic ---
    # An awake driver never triggers an alert.
    if drowsiness_level == "Awake":
        return processed_frame, status_text, None

    # Per the original author's note, trigger_alert() returns the alert TEXT
    # when not on cooldown and None otherwise, so repeated drowsy frames do
    # not start a new stream while one is already active.
    alert_text = alerter.trigger_alert(level=drowsiness_level)

    # NOTE(review): this hands Gradio a generator *object* as the Audio
    # value rather than making the handler itself a generator — confirm that
    # the Gradio version in use accepts this for a streaming Audio output.
    audio_output = stream_alert_audio(alert_text) if alert_text else None

    return processed_frame, status_text, audio_output


# --- Gradio UI Definition ---
# Layout: the live webcam input in a wide left column; the processed feed,
# status text and the alert audio player in a narrower right column.
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    # The heading emoji had been corrupted into mojibake; restored to the
    # car emoji (U+1F697) that the garbled byte sequence encodes.
    gr.Markdown("# 🚗 Drive Paddy - Drowsiness Detection (Streaming)")
    gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")

    with gr.Row():
        with gr.Column(scale=2):
            webcam_input = gr.Image(sources=["webcam"], streaming=True, label="Live Camera Feed")
        with gr.Column(scale=1):
            processed_output = gr.Image(label="Processed Feed")
            status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)

            # --- KEY CHANGE: The Audio component now uses streaming=True ---
            # NOTE(review): with visible=False some Gradio versions may not
            # render the component at all, which could prevent autoplay —
            # confirm against the Gradio version in use.
            audio_alert_output = gr.Audio(
                label="Alert System",
                autoplay=True,
                visible=False, # Hide the player controls
                streaming=True
            )

    # Every streamed webcam frame is passed to process_live_frame; its three
    # return values map, in order, onto the three output components.
    webcam_input.stream(
        fn=process_live_frame,
        inputs=[webcam_input],
        outputs=[processed_output, status_output, audio_alert_output]
    )


# --- Launch the App ---
# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    app.launch(debug=True)