# app_gradio.py
import gradio as gr
import numpy as np
import torch
import os, yaml, soundfile as sf
from dotenv import load_dotenv
from threading import Thread

# --- TTS & AI Imports ---
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor
from streamer import ParlerTTSStreamer                     # local file

from src.detection.factory import get_detector
from src.alerting.alert_system import get_alerter

# ──────────────────────────────────────────────────────────
# CONFIG & BACKEND SET-UP
# ──────────────────────────────────────────────────────────
load_dotenv()

with open("config.yaml", "r") as f:
    config = yaml.safe_load(f)

secrets = {"gemini_api_key": os.getenv("GEMINI_API_KEY")}

print("Initializing detector and alerter …")
detector = get_detector(config)
alerter = get_alerter(config, secrets["gemini_api_key"])
print("Backend ready.")

# ──────────────────────────────────────────────────────────
# TTS MODEL (Parler-TTS mini)
# ──────────────────────────────────────────────────────────
device = "cuda:0" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("\n⚠️  Running TTS on CPU will be slow; only 'Very Drowsy' alerts will use it.\n")

model_dtype = torch.float16 if device != "cpu" else torch.float32
repo_id = "parler-tts/parler_tts_mini_v0.1"

print("Loading Parler-TTS …")
model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id,
                                                         torch_dtype=model_dtype).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
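# Note: feature_extractor is loaded here but not used directly in this script.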
print("TTS loaded.")

# ──────────────────────────────────────────────────────────
# AUDIO STREAMER
# ──────────────────────────────────────────────────────────
def stream_alert_audio(text_prompt: str):
    """Yields (sampling_rate, np.ndarray) chunks for Gradio streaming."""
    sampling_rate = model.config.sampling_rate
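    # Parler-TTS conditions generation on a natural-language description of the target voice.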
    voice_desc = "Jenny is a female speaker with a clear and urgent voice."

    prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
    desc_ids   = tokenizer(voice_desc, return_tensors="pt").input_ids.to(device)

    # play_steps sets how much audio is generated before each chunk is yielded.
    streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))

    gen_kwargs = dict(
        input_ids=desc_ids,
        prompt_input_ids=prompt_ids,
        streamer=streamer,
        do_sample=True,
        temperature=1.0,
        repetition_penalty=1.2,
    )

    # Run generation in a daemon thread so chunks can be consumed while the model is still generating.
    thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)

    try:
        thread.start()
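        # Yield chunks incrementally so playback can start before generation finishes.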
        for chunk in streamer:
            yield (sampling_rate, chunk)
    finally:
        thread.join(timeout=0.1)
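        # Reset the alerter so the next drowsiness event can trigger a new alert.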
        alerter.reset_alert()

# ──────────────────────────────────────────────────────────
# FRAME PROCESSOR
# ──────────────────────────────────────────────────────────
def process_live_frame(frame):
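    """Runs detection on one webcam frame; returns (processed image, status text, audio)."""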
    if frame is None:
        return np.zeros((480, 640, 3), np.uint8), "Status: Inactive", None

    processed, indicators, _ = detector.process_frame(frame)
    level      = indicators.get("drowsiness_level", "Awake")
    lighting   = indicators.get("lighting", "Good")
    score      = indicators.get("details", {}).get("Score", 0)

    status_txt = f"Lighting: {lighting}\n"
    status_txt += ("Detection paused due to low light."
                   if lighting == "Low"
                   else f"Status: {level}\nScore: {score:.2f}")

    audio_out = None
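    # Fire an audio alert only when the driver is not fully awake and lighting is adequate.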
    if level != "Awake" and lighting != "Low":
        payload = alerter.trigger_alert(level=level)
        if payload:
            # Static file path → bytes, dynamic Gemini path → str
            if isinstance(payload, bytes):
                # Return raw bytes (Gradio accepts bytes for .wav / .mp3)
                audio_out = payload
            elif isinstance(payload, str):
                audio_out = stream_alert_audio(payload)

    return processed, status_txt, audio_out

# ──────────────────────────────────────────────────────────
# GRADIO UI
# ──────────────────────────────────────────────────────────
with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
    gr.Markdown("# 🚗 Drive Paddy – Drowsiness Detection")
    gr.Markdown("Live detection with real-time voice alerts.")

    with gr.Row():
        with gr.Column(scale=2):
            webcam = gr.Image(sources=["webcam"], streaming=True,
                              label="Live Camera Feed")
        with gr.Column(scale=1):
            processed_img = gr.Image(label="Processed Feed")
            status_box    = gr.Textbox(label="Live Status", lines=3, interactive=False)
            alert_audio   = gr.Audio(label="Alert",
                                     autoplay=True,
                                     streaming=True)

    # Wire the webcam stream to the frame processor; all three outputs refresh on every frame.
    webcam.stream(
        fn=process_live_frame,
        inputs=webcam,
        outputs=[processed_img, status_box, alert_audio],
    )

if __name__ == "__main__":
    app.launch(debug=True)