Testys commited on
Commit
ba6a8ea
Β·
verified Β·
1 Parent(s): 29b0bbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -109
app.py CHANGED
@@ -2,157 +2,136 @@
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
- import soundfile as sf
6
- import os
7
- import yaml
8
  from dotenv import load_dotenv
9
  from threading import Thread
10
 
11
  # --- TTS & AI Imports ---
12
  from parler_tts import ParlerTTSForConditionalGeneration
13
- from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
14
- from streamer import ParlerTTSStreamer # Make sure streamer.py is available
15
 
16
  from src.detection.factory import get_detector
17
  from src.alerting.alert_system import get_alerter
18
 
19
- # --- Load Configuration and Environment Variables ---
20
- # This part is the same as our Streamlit app
 
21
  load_dotenv()
22
- config_path = 'config.yaml'
23
- with open(config_path, 'r') as f:
24
  config = yaml.safe_load(f)
25
- secrets = {
26
- "gemini_api_key": os.getenv("GEMINI_API_KEY"),
27
- }
28
 
29
- # --- Initialize Backend Components ---
30
- print("Initializing detector and alerter...")
 
31
  detector = get_detector(config)
32
  alerter = get_alerter(config, secrets["gemini_api_key"])
33
- print("Initialization complete. Launching UI...")
34
 
35
- # --- Parler-TTS Model Setup (Requires GPU) ---
36
- print("Loading Parler-TTS model. This may take a moment...")
 
37
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
38
  if device == "cpu":
39
- print("\nWARNING: Running Parler-TTS on a CPU will be extremely slow. A GPU is highly recommended.\n")
40
- torch_dtype = torch.float16 if device != "cpu" else torch.float32
41
-
42
 
43
- # Using a smaller, faster model suitable for real-time alerts
44
  repo_id = "parler-tts/parler_tts_mini_v0.1"
45
- model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id, torch_dtype=torch_dtype).to(device)
 
 
 
46
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
47
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
48
- print("Parler-TTS model loaded.")
49
-
50
- # --- Audio Streaming Generator Function ---
51
- def stream_alert_audio(text_prompt):
52
- """
53
- A generator function that yields audio chunks for a given text prompt.
54
- This is the core of the streaming implementation.
55
- """
56
  sampling_rate = model.config.sampling_rate
57
- description = "Jenny is A female speaker with a clear and urgent voice." # Voice prompt for TTS
58
-
59
  prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
60
- description_ids = tokenizer(description, return_tensors="pt").input_ids.to(device)
61
 
62
- # Setup the streamer
63
  streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))
64
-
65
- generation_kwargs = dict(
66
- input_ids=description_ids,
67
  prompt_input_ids=prompt_ids,
68
  streamer=streamer,
69
  do_sample=True,
70
- temperature=1.0, # Increase for more vocal variety
71
  repetition_penalty=1.2,
72
  )
73
-
74
- # Run generation in a separate thread to not block the UI
75
- thread = Thread(target=model.generate, kwargs=generation_kwargs)
76
-
77
  try:
78
  thread.start()
79
- print(f"Audio stream started for: '{text_prompt}'")
80
- # Yield audio chunks as they become available
81
- for new_audio_chunk in streamer:
82
- yield new_audio_chunk
83
  finally:
84
- # CRITICAL: This block runs after the generator is exhausted (audio finishes)
85
- # We reset the alerter state so that a new alert can be triggered later.
86
- print("Audio stream finished. Resetting alerter state.")
87
  alerter.reset_alert()
88
-
89
- # --- Main Webcam Processing Function ---
 
 
90
  def process_live_frame(frame):
91
- """
92
- Processes each webcam frame, performs drowsiness detection, and
93
- returns a generator for audio streaming when an alert is triggered.
94
- """
95
  if frame is None:
96
- return np.zeros((480, 640, 3), dtype=np.uint8), "Status: Inactive", None
97
-
98
- processed_frame, indicators, _ = detector.process_frame(frame)
99
- drowsiness_level = indicators.get("drowsiness_level", "Awake")
100
- lighting = indicators.get("lighting", "Good")
101
- score = indicators.get("details", {}).get("Score", 0)
102
-
103
- # Build status text
104
- status_text = f"Lighting: {lighting}\n"
105
- if lighting == "Low":
106
- status_text += "Detection paused due to low light."
107
- else:
108
- status_text += f"Status: {drowsiness_level}\nScore: {score:.2f}"
109
-
110
- # --- Alert Trigger Logic ---
111
- audio_output = None
112
- if drowsiness_level != "Awake":
113
- # alerter.trigger_alert() returns the alert TEXT if not on cooldown, otherwise None.
114
- alert_text = alerter.trigger_alert(level=drowsiness_level)
115
- if alert_text:
116
- # If we got text, it means we can start an alert.
117
- # We return the generator function itself. Gradio will handle it.
118
- audio_output = stream_alert_audio(alert_text)
119
-
120
- else:
121
- audio_output = stream_alert_audio(alert_text)
122
-
123
- # On subsequent frames where the user is drowsy, trigger_alert() will return None
124
- # due to the cooldown, preventing a new stream from starting, which is what we want.
125
-
126
- return processed_frame, status_text, audio_output
127
-
128
-
129
- # --- Gradio UI Definition ---
130
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
131
- gr.Markdown("# πŸš— Drive Paddy - Drowsiness Detection (Streaming)")
132
- gr.Markdown("Live drowsiness detection with real-time, streaming voice alerts.")
133
 
134
  with gr.Row():
135
  with gr.Column(scale=2):
136
- webcam_input = gr.Image(sources=["webcam"], streaming=True, label="Live Camera Feed")
 
137
  with gr.Column(scale=1):
138
- processed_output = gr.Image(label="Processed Feed")
139
- status_output = gr.Textbox(label="Live Status", lines=3, interactive=False)
140
-
141
- # --- KEY CHANGE: The Audio component now uses streaming=True ---
142
- audio_alert_output = gr.Audio(
143
- label="Alert System",
144
- autoplay=True,
145
- visible=False, # Hide the player controls
146
- streaming=True
147
- )
148
-
149
- webcam_input.stream(
150
  fn=process_live_frame,
151
- inputs=[webcam_input],
152
- outputs=[processed_output, status_output, audio_alert_output]
153
  )
154
 
155
-
156
- # --- Launch the App ---
157
  if __name__ == "__main__":
158
- app.launch(debug=True)
 
2
  import gradio as gr
3
  import numpy as np
4
  import torch
5
+ import os, yaml, soundfile as sf
 
 
6
  from dotenv import load_dotenv
7
  from threading import Thread
8
 
9
  # --- TTS & AI Imports ---
10
  from parler_tts import ParlerTTSForConditionalGeneration
11
+ from transformers import AutoTokenizer, AutoFeatureExtractor
12
+ from streamer import ParlerTTSStreamer # local file
13
 
14
  from src.detection.factory import get_detector
15
  from src.alerting.alert_system import get_alerter
16
 
17
+ # ──────────────────────────────────────────────────────────
18
+ # CONFIG & BACKEND SET-UP
19
+ # ──────────────────────────────────────────────────────────
20
  load_dotenv()
21
+
22
+ with open("config.yaml", "r") as f:
23
  config = yaml.safe_load(f)
 
 
 
24
 
25
+ secrets = {"gemini_api_key": os.getenv("GEMINI_API_KEY")}
26
+
27
+ print("Initializing detector and alerter …")
28
  detector = get_detector(config)
29
  alerter = get_alerter(config, secrets["gemini_api_key"])
30
+ print("Backend ready.")
31
 
32
+ # ──────────────────────────────────────────────────────────
33
+ # TTS MODEL (Parler-TTS mini)
34
+ # ──────────────────────────────────────────────────────────
35
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
36
  if device == "cpu":
37
+ print("\n⚠️ Running TTS on CPU will be slow; only β€˜Very Drowsy’ alerts will use it.\n")
 
 
38
 
39
+ model_dtype = torch.float16 if device != "cpu" else torch.float32
40
  repo_id = "parler-tts/parler_tts_mini_v0.1"
41
+
42
+ print("Loading Parler-TTS …")
43
+ model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id,
44
+ torch_dtype=model_dtype).to(device)
45
  tokenizer = AutoTokenizer.from_pretrained(repo_id)
46
  feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
47
+ print("TTS loaded.")
48
+
49
+ # ──────────────────────────────────────────────────────────
50
+ # AUDIO STREAMER
51
+ # ──────────────────────────────────────────────────────────
52
+ def stream_alert_audio(text_prompt: str):
53
+ """Yields (sampling_rate, np.ndarray) chunks for Gradio streaming."""
 
54
  sampling_rate = model.config.sampling_rate
55
+ voice_desc = "Jenny is a female speaker with a clear and urgent voice."
56
+
57
  prompt_ids = tokenizer(text_prompt, return_tensors="pt").input_ids.to(device)
58
+ desc_ids = tokenizer(voice_desc, return_tensors="pt").input_ids.to(device)
59
 
 
60
  streamer = ParlerTTSStreamer(model, device, play_steps=int(sampling_rate * 2.0))
61
+
62
+ gen_kwargs = dict(
63
+ input_ids=desc_ids,
64
  prompt_input_ids=prompt_ids,
65
  streamer=streamer,
66
  do_sample=True,
67
+ temperature=1.0,
68
  repetition_penalty=1.2,
69
  )
70
+
71
+ thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
72
+
 
73
  try:
74
  thread.start()
75
+ for chunk in streamer:
76
+ yield (sampling_rate, chunk)
 
 
77
  finally:
78
+ thread.join(timeout=0.1)
 
 
79
  alerter.reset_alert()
80
+
81
+ # ──────────────────────────────────────────────────────────
82
+ # FRAME PROCESSOR
83
+ # ──────────────────────────────────────────────────────────
84
  def process_live_frame(frame):
 
 
 
 
85
  if frame is None:
86
+ return np.zeros((480, 640, 3), np.uint8), "Status: Inactive", None
87
+
88
+ processed, indicators, _ = detector.process_frame(frame)
89
+ level = indicators.get("drowsiness_level", "Awake")
90
+ lighting = indicators.get("lighting", "Good")
91
+ score = indicators.get("details", {}).get("Score", 0)
92
+
93
+ status_txt = f"Lighting: {lighting}\n"
94
+ status_txt += ("Detection paused due to low light."
95
+ if lighting == "Low"
96
+ else f"Status: {level}\nScore: {score:.2f}")
97
+
98
+ audio_out = None
99
+ if level != "Awake" and lighting != "Low":
100
+ payload = alerter.trigger_alert(level=level)
101
+ if payload:
102
+ # Static file path β†’ bytes, Dynamic Gemini path β†’ str
103
+ if isinstance(payload, bytes):
104
+ # Return raw bytes (Gradio accepts bytes for .wav / .mp3)
105
+ audio_out = payload
106
+ elif isinstance(payload, str):
107
+ audio_out = stream_alert_audio(payload)
108
+
109
+ return processed, status_txt, audio_out
110
+
111
+ # ──────────────────────────────────────────────────────────
112
+ # GRADIO UI
113
+ # ──────────────────────────────────────────────────────────
 
 
 
 
 
 
114
  with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as app:
115
+ gr.Markdown("# πŸš— Drive Paddy – Drowsiness Detection")
116
+ gr.Markdown("Live detection with real-time voice alerts.")
117
 
118
  with gr.Row():
119
  with gr.Column(scale=2):
120
+ webcam = gr.Image(sources=["webcam"], streaming=True,
121
+ label="Live Camera Feed")
122
  with gr.Column(scale=1):
123
+ processed_img = gr.Image(label="Processed Feed")
124
+ status_box = gr.Textbox(label="Live Status", lines=3, interactive=False)
125
+ alert_audio = gr.Audio(label="Alert",
126
+ autoplay=True,
127
+ streaming=True,
128
+ height=40)
129
+
130
+ webcam.stream(
 
 
 
 
131
  fn=process_live_frame,
132
+ inputs=webcam,
133
+ outputs=[processed_img, status_box, alert_audio],
134
  )
135
 
 
 
136
  if __name__ == "__main__":
137
+ app.launch(debug=True)