# app_gradio.py
import gradio as gr
import numpy as np
import torch
import os
import yaml
from dotenv import load_dotenv
from threading import Thread
from gradio_webrtc import WebRTC
from twilio.rest import Client

# --- TTS & AI Imports ---
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
from streamer import ParlerTTSStreamer 
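# (streamer.py is a local helper module that exposes model.generate as an
# iterator of audio chunks, in the style of transformers' TextStreamer.)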

# --- Local Project Imports ---
from src.detection.factory import get_detector
from src.alerting.alert_system import get_alerter

# --- Load Configuration and Environment Variables ---
load_dotenv()
config_path = 'config.yaml'
with open(config_path, 'r') as f:
    config = yaml.safe_load(f)
secrets = {"gemini_api_key": os.getenv("GEMINI_API_KEY")}

# --- Initialize Backend Components ---
print("Initializing detector and alerter...")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Initialization complete.")

# --- Twilio TURN Server Setup ---
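# WebRTC needs a TURN relay when clients sit behind restrictive NATs or
# firewalls; a STUN server alone only discovers public addresses. The env
# vars below are assumed to hold Twilio API credentials despite their names.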
account_sid = os.environ.get("TURN_USERNAME")
auth_token = os.environ.get("TURN_CREDENTIAL")
rtc_configuration = None
if account_sid and auth_token:
    try:
        client = Client(account_sid, auth_token)
        token = client.tokens.create()
        rtc_configuration = {"iceServers": token.ice_servers}
        print("Twilio TURN server configured successfully.")
    except Exception as e:
        print(f"Warning: Failed to create Twilio token. Using public STUN server. Error: {e}")
# Fallback to a public STUN server if Twilio fails or is not configured
if rtc_configuration is None:
    print("Using public STun server.")
    rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}


# --- Parler-TTS Model Setup ---
print("Loading Parler-TTS model...")
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("\nWARNING: Running Parler-TTS on a CPU is slow. A GPU is highly recommended.\n")
torch_dtype = torch.float16 if device != "cpu" else torch.float32
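# float16 halves memory use and speeds up inference on GPU; CPU kernels generally require float32.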

repo_id = "parler-tts/parler_tts_mini_v0.1"
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
print("Parler-TTS model loaded.")

# --- Audio Streaming Generator ---
def stream_alert_audio(text_prompt):
    """A generator that streams audio chunks for a given text prompt."""
    sampling_rate = model.audio_encoder.config.sampling_rate
    frame_rate = model.audio_encoder.config.frame_rate
    description = "A female speaker with a clear and urgent voice."
    prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
    description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
    # play_steps is counted in codec frames, so scale by the codec frame rate
    # (not the audio sampling rate) to yield roughly two seconds per chunk.
    streamer = ParlerTTSStreamer(model, device, play_steps=int(frame_rate * 2.0))
    generation_kwargs = dict(
        input_ids=description_ids,
        prompt_input_ids=prompt_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        repetition_penalty=1.2,
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    try:
        thread.start()
        print(f"Audio stream started for: '{text_prompt}'")
        for new_audio_chunk in streamer:
            yield (sampling_rate, new_audio_chunk)
    finally:
        print("Audio stream finished. Resetting alerter state.")
        alerter.reset_alert()

# --- Decoupled Processing Functions ---

def process_video_and_update_state(frame, state: dict):
    """
    HIGH-FREQUENCY LOOP: Processes video, updates shared state, and returns the processed frame.
    This function's speed directly impacts video latency.
    """
    # Depending on the component version, the frame may arrive as a bare numpy
    # array or wrapped in a dict under a "video" key; handle both.
    if isinstance(frame, dict):
        frame = frame.get("video")
    if frame is None:
        return np.zeros((480, 640, 3), dtype=np.uint8), state

    processed_frame, indicators, _ = detector.process_frame(frame)
    state['indicators'] = indicators
    return processed_frame, state

def update_ui_from_state(state: dict):
    """
    LOW-FREQUENCY LOOP: Reads from state to update status text and trigger audio.
    This runs independently of the video loop.
    """
    indicators = state.get('indicators', {})
    drowsiness_level = indicators.get("drowsiness_level", "Awake")
    lighting = indicators.get("lighting", "Good")
    score = indicators.get("details", {}).get("Score", 0)

    status_text = f"Lighting: {lighting}\nStatus: {drowsiness_level}\nScore: {score:.2f}"
    if lighting == "Low":
        status_text = "Lighting: Low\nDetection paused due to low light."

    if drowsiness_level != "Awake":
        alert_text = alerter.trigger_alert(level=drowsiness_level)
        if alert_text:
            # Gradio only streams output when the handler itself is a
            # generator, so yield each chunk rather than returning the
            # generator object as a value.
            for audio_chunk in stream_alert_audio(alert_text):
                yield status_text, audio_chunk
            return
    yield status_text, None

# --- Gradio UI Definition ---
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# πŸš— Drive Paddy - Drowsiness Detection (WebRTC)")
    gr.Markdown("Low-latency video processing via WebRTC, with decoupled UI updates for smooth performance.")
    
    # Shared state object to pass data between the two processing loops
    shared_state = gr.State(value={'indicators': {}})
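    # NOTE: gr.State is per-session, so each connected browser gets its own copy.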

    with gr.Row():
        with gr.Column(scale=2):
            # The WebRTC component serves as both the input and the output of the video stream.
            webcam = WebRTC(label="Live Camera Feed", rtc_configuration=rtc_configuration)
        with gr.Column(scale=1):
            status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
            audio_alert_output = gr.Audio(label="Alert System", autoplay=True, visible=False, streaming=True)
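            # streaming=True lets the Audio component start playing chunks as
            # they arrive; visible=False keeps the alert audio-only.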
            
    # LOOP 1: High-Frequency Video Stream (as fast as possible)
    # This takes video from the webcam, processes it, and sends it right back.
    webcam.stream(
        fn=process_video_and_update_state,
        inputs=[webcam, shared_state],
        outputs=[webcam, shared_state],
    )

    # LOOP 2: Low-Frequency UI Updates (4 times per second)
    # This runs on a timer, reads the shared state, and updates the other UI elements.
    # Without `every`, app.load would fire only once when the page loads.
    app.load(
        fn=update_ui_from_state,
        inputs=[shared_state],
        outputs=[status_output, audio_alert_output],
        every=0.25,
    )

if __name__ == "__main__":
    print("Starting Drive Paddy WebRTC Application...")
    app.launch(debug=True)